Skip to content

Commit

Permalink
#45: Added weight field and WRR routing strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
roma-glushko committed Jan 14, 2024
1 parent 01f46d2 commit b37972d
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 20 deletions.
4 changes: 3 additions & 1 deletion pkg/providers/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type LangModelConfig struct {
ID string `yaml:"id" json:"id" validate:"required"` // Model instance ID (unique in scope of the router)
Enabled bool `yaml:"enabled" json:"enabled"` // Is the model enabled?
ErrorBudget health.ErrorBudget `yaml:"error_budget" json:"error_budget" swaggertype:"primitive,string"`
Weight int `yaml:"weight" json:"weight"`
Client *clients.ClientConfig `yaml:"client" json:"client"`
OpenAI *openai.Config `yaml:"openai" json:"openai"`
// Add other providers like
Expand All @@ -30,6 +31,7 @@ func DefaultLangModelConfig() *LangModelConfig {
Enabled: true,
Client: clients.DefaultClientConfig(),
ErrorBudget: health.DefaultErrorBudget(),
Weight: 1,
}
}

Expand All @@ -40,7 +42,7 @@ func (c *LangModelConfig) ToModel(tel *telemetry.Telemetry) (*LangModel, error)
return nil, fmt.Errorf("error initing openai client: %v", err)
}

return NewLangModel(c.ID, client, c.ErrorBudget), nil
return NewLangModel(c.ID, client, c.ErrorBudget, c.Weight), nil
}

return nil, ErrProviderNotFound
Expand Down
9 changes: 8 additions & 1 deletion pkg/providers/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type LangModelProvider interface {
// Model is the provider-agnostic view of a model: its identity, its current
// health, and its relative weight.
type Model interface {
// ID returns the model instance ID.
ID() string
// Healthy reports whether the model is currently considered usable.
Healthy() bool
// Weight returns the model's configured weight.
Weight() int
}

type LanguageModel interface {
Expand All @@ -29,14 +30,16 @@ type LanguageModel interface {
// LangModel bundles a provider client with the model's ID, its weight,
// and the trackers (rate limit, error budget) consulted by Healthy.
type LangModel struct {
// modelID is the model instance ID.
modelID string
// weight is the value returned by Weight().
weight int
// client is the provider client that Chat requests are delegated to.
client LangModelProvider
// rateLimit tracks whether the model is currently rate limited.
rateLimit *health.RateLimitTracker
errorBudget *health.TokenBucket // TODO: centralize provider API health tracking in the registry
}

func NewLangModel(modelID string, client LangModelProvider, budget health.ErrorBudget) *LangModel {
func NewLangModel(modelID string, client LangModelProvider, budget health.ErrorBudget, weight int) *LangModel {
return &LangModel{
modelID: modelID,
weight: weight,
client: client,
rateLimit: health.NewRateLimitTracker(),
errorBudget: health.NewTokenBucket(budget.TimePerTokenMicro(), budget.Budget()),
Expand All @@ -55,6 +58,10 @@ func (m *LangModel) Healthy() bool {
return !m.rateLimit.Limited() && m.errorBudget.HasTokens()
}

func (m *LangModel) Weight() int {
return m.weight
}

func (m *LangModel) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) {
resp, err := m.client.Chat(ctx, request)

Expand Down
8 changes: 7 additions & 1 deletion pkg/providers/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,14 @@ func (c *ProviderMock) Provider() string {
// LangModelMock is a test double for a model whose ID, health flag and
// weight are fixed at construction time.
type LangModelMock struct {
// modelID is the value returned by ID().
modelID string
// healthy is the value returned by Healthy().
healthy bool
// weight is the value returned by Weight().
weight int
}

func NewLangModelMock(ID string, healthy bool) *LangModelMock {
func NewLangModelMock(ID string, healthy bool, weight int) *LangModelMock {
return &LangModelMock{
modelID: ID,
healthy: healthy,
weight: weight,
}
}

Expand All @@ -71,3 +73,7 @@ func (m *LangModelMock) ID() string {
func (m *LangModelMock) Healthy() bool {
return m.healthy
}

func (m *LangModelMock) Weight() int {
return m.weight
}
4 changes: 3 additions & 1 deletion pkg/routers/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,11 @@ func (c *LangRouterConfig) BuildRouting(models []providers.LanguageModel) (routi

switch c.RoutingStrategy {
case routing.Priority:
return routing.NewPriorityRouting(m), nil
return routing.NewPriority(m), nil
case routing.RoundRobin:
return routing.NewRoundRobinRouting(m), nil
case routing.WeightedRoundRobin:
return routing.NewWeightedRoundRobin(m), nil
}

return nil, fmt.Errorf("routing strategy \"%v\" is not supported, please make sure there is no typo", c.RoutingStrategy)
Expand Down
21 changes: 16 additions & 5 deletions pkg/routers/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ func TestLangRouter_Priority_PickFistHealthy(t *testing.T) {
"first",
providers.NewProviderMock([]providers.ResponseMock{{Msg: "1"}, {Msg: "2"}}),
*budget,
1,
),
providers.NewLangModel(
"second",
providers.NewProviderMock([]providers.ResponseMock{{Msg: "1"}}),
*budget,
1,
),
}

Expand All @@ -40,7 +42,7 @@ func TestLangRouter_Priority_PickFistHealthy(t *testing.T) {
routerID: "test_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(3, 2, 1*time.Second, nil),
routing: routing.NewPriorityRouting(models),
routing: routing.NewPriority(models),
models: langModels,
telemetry: telemetry.NewTelemetryMock(),
}
Expand All @@ -64,16 +66,19 @@ func TestLangRouter_Priority_PickThirdHealthy(t *testing.T) {
"first",
providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "3"}}),
*budget,
1,
),
providers.NewLangModel(
"second",
providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "4"}}),
*budget,
1,
),
providers.NewLangModel(
"third",
providers.NewProviderMock([]providers.ResponseMock{{Msg: "1"}, {Msg: "2"}}),
*budget,
1,
),
}

Expand All @@ -88,7 +93,7 @@ func TestLangRouter_Priority_PickThirdHealthy(t *testing.T) {
routerID: "test_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(3, 2, 1*time.Second, nil),
routing: routing.NewPriorityRouting(models),
routing: routing.NewPriority(models),
models: langModels,
telemetry: telemetry.NewTelemetryMock(),
}
Expand All @@ -112,11 +117,13 @@ func TestLangRouter_Priority_SuccessOnRetry(t *testing.T) {
"first",
providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "2"}}),
*budget,
1,
),
providers.NewLangModel(
"second",
providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "1"}}),
*budget,
1,
),
}

Expand All @@ -129,7 +136,7 @@ func TestLangRouter_Priority_SuccessOnRetry(t *testing.T) {
routerID: "test_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil),
routing: routing.NewPriorityRouting(models),
routing: routing.NewPriority(models),
models: langModels,
telemetry: telemetry.NewTelemetryMock(),
}
Expand All @@ -148,11 +155,13 @@ func TestLangRouter_Priority_UnhealthyModelInThePool(t *testing.T) {
"first",
providers.NewProviderMock([]providers.ResponseMock{{Err: &clients.ErrProviderUnavailable}, {Msg: "3"}}),
*budget,
1,
),
providers.NewLangModel(
"second",
providers.NewProviderMock([]providers.ResponseMock{{Msg: "1"}, {Msg: "2"}}),
*budget,
1,
),
}

Expand All @@ -165,7 +174,7 @@ func TestLangRouter_Priority_UnhealthyModelInThePool(t *testing.T) {
routerID: "test_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil),
routing: routing.NewPriorityRouting(models),
routing: routing.NewPriority(models),
models: langModels,
telemetry: telemetry.NewTelemetryMock(),
}
Expand All @@ -186,11 +195,13 @@ func TestLangRouter_Priority_AllModelsUnavailable(t *testing.T) {
"first",
providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Err: &ErrNoModelAvailable}}),
*budget,
1,
),
providers.NewLangModel(
"second",
providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Err: &ErrNoModelAvailable}}),
*budget,
1,
),
}

Expand All @@ -203,7 +214,7 @@ func TestLangRouter_Priority_AllModelsUnavailable(t *testing.T) {
routerID: "test_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(1, 2, 1*time.Millisecond, nil),
routing: routing.NewPriorityRouting(models),
routing: routing.NewPriority(models),
models: langModels,
telemetry: telemetry.NewTelemetryMock(),
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/routers/routing/priority.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type PriorityRouting struct {
models []providers.Model
}

func NewPriorityRouting(models []providers.Model) *PriorityRouting {
func NewPriority(models []providers.Model) *PriorityRouting {
return &PriorityRouting{
models: models,
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/routers/routing/priority_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ func TestPriorityRouting_PickModelsInOrder(t *testing.T) {
models := make([]providers.Model, 0, len(tc.models))

for _, model := range tc.models {
models = append(models, providers.NewLangModelMock(model.modelID, model.healthy))
models = append(models, providers.NewLangModelMock(model.modelID, model.healthy, 1))
}

routing := NewPriorityRouting(models)
routing := NewPriority(models)
iterator := routing.Iterator()

// loop three times over the whole pool to check if we return to the beginning of the list
Expand All @@ -47,12 +47,12 @@ func TestPriorityRouting_PickModelsInOrder(t *testing.T) {

func TestPriorityRouting_NoHealthyModels(t *testing.T) {
models := []providers.Model{
providers.NewLangModelMock("first", false),
providers.NewLangModelMock("second", false),
providers.NewLangModelMock("third", false),
providers.NewLangModelMock("first", false, 1),
providers.NewLangModelMock("second", false, 1),
providers.NewLangModelMock("third", false, 1),
}

routing := NewPriorityRouting(models)
routing := NewPriority(models)
iterator := routing.Iterator()

_, err := iterator.Next()
Expand Down
8 changes: 4 additions & 4 deletions pkg/routers/routing/round_robin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestRoundRobinRouting_PickModelsSequentially(t *testing.T) {
models := make([]providers.Model, 0, len(tc.models))

for _, model := range tc.models {
models = append(models, providers.NewLangModelMock(model.modelID, model.healthy))
models = append(models, providers.NewLangModelMock(model.modelID, model.healthy, 1))
}

routing := NewRoundRobinRouting(models)
Expand All @@ -50,9 +50,9 @@ func TestRoundRobinRouting_PickModelsSequentially(t *testing.T) {

func TestRoundRobinRouting_NoHealthyModels(t *testing.T) {
models := []providers.Model{
providers.NewLangModelMock("first", false),
providers.NewLangModelMock("second", false),
providers.NewLangModelMock("third", false),
providers.NewLangModelMock("first", false, 1),
providers.NewLangModelMock("second", false, 1),
providers.NewLangModelMock("third", false, 1),
}

routing := NewRoundRobinRouting(models)
Expand Down
83 changes: 83 additions & 0 deletions pkg/routers/routing/weighted_round_robin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package routing

import (
"sync"

"glide/pkg/providers"
)

const (
// WeightedRoundRobin is the Strategy value that selects the weighted
// round-robin routing.
// NOTE(review): the value is spelled "weighed-round-robin" (not "weighted");
// this looks like a typo, but changing it would change the string compared
// against the configured routing strategy — confirm before renaming.
WeightedRoundRobin Strategy = "weighed-round-robin"
)

// Weighter couples a model with a running ("current") weight that is
// raised by the model's own weight and lowered by a caller-supplied total.
type Weighter struct {
	model         providers.Model
	currentWeight int
}

// Current returns the running weight accumulated so far.
func (w *Weighter) Current() int {
	current := w.currentWeight

	return current
}

// Weight returns the static weight reported by the wrapped model.
func (w *Weighter) Weight() int {
	static := w.model.Weight()

	return static
}

// Incr raises the running weight by the wrapped model's static weight.
func (w *Weighter) Incr() {
	w.currentWeight = w.currentWeight + w.model.Weight()
}

// Decr lowers the running weight by totalWeight.
func (w *Weighter) Decr(totalWeight int) {
	w.currentWeight = w.currentWeight - totalWeight
}

// WRoundRobinRouting picks models in proportion to their weights using the
// smooth weighted round-robin scheme: on every pick each healthy model's
// running weight grows by its own weight, the model with the highest running
// weight is chosen, and the chosen model's running weight is reduced by the
// sum of the healthy models' weights.
type WRoundRobinRouting struct {
	// mu guards weights, which Next mutates on every call.
	mu      sync.Mutex
	weights []Weighter
}

// NewWeightedRoundRobin builds a weighted round-robin routing over models.
// Every model starts with a running weight of zero.
func NewWeightedRoundRobin(models []providers.Model) *WRoundRobinRouting {
	weights := make([]Weighter, 0, len(models))

	// Register every model; without this the routing has nothing to pick
	// from and Next can only ever return ErrNoHealthyModels.
	for _, model := range models {
		weights = append(weights, Weighter{model: model})
	}

	return &WRoundRobinRouting{
		weights: weights,
	}
}

// Iterator returns the routing itself, so all callers share the same running
// weights.
func (r *WRoundRobinRouting) Iterator() LangModelIterator {
	return r
}

// Next returns the healthy model with the highest running weight after this
// round's increments, or ErrNoHealthyModels when no model reports healthy.
// Unhealthy models are skipped and their running weights are left untouched.
func (r *WRoundRobinRouting) Next() (providers.Model, error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	totalWeight := 0

	var maxWeighter *Weighter

	// Iterate by index: ranging by value would hand us a copy of each
	// Weighter, so Incr/Decr would be lost and the pointer kept in
	// maxWeighter would alias the loop variable rather than the slice element.
	for i := range r.weights {
		weighter := &r.weights[i]

		if !weighter.model.Healthy() {
			continue
		}

		weighter.Incr()
		totalWeight += weighter.Weight()

		// Strict ">" keeps the earliest model on ties.
		if maxWeighter == nil || weighter.Current() > maxWeighter.Current() {
			maxWeighter = weighter
		}
	}

	if maxWeighter == nil {
		return nil, ErrNoHealthyModels
	}

	maxWeighter.Decr(totalWeight)

	return maxWeighter.model, nil
}

0 comments on commit b37972d

Please sign in to comment.