Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 169 additions & 0 deletions pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
/*
© 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

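Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.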
*/

// Package slo_aware_router contains helpers to decouple latency-predictor logic.
package slo_aware_router

import (
"os"
"strconv"
"strings"
)

// SLOBufferFactor is resolved once at package initialization from the
// SLO_BUFFER_FACTOR environment variable.
//
// The default of 1.0 is used when the variable is unset, is not a valid
// float, or parses to NaN or a negative number. The range check mirrors the
// validation applied to the other float settings in this file.
var SLOBufferFactor = func() float64 {
	if value, exists := os.LookupEnv("SLO_BUFFER_FACTOR"); exists {
		// parsedValue >= 0 is false for NaN, so NaN is rejected as well.
		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 {
			return parsedValue
		}
	}
	return 1.0 // default value
}()

// NegHeadroomTTFTWeight is resolved once at package initialization from the
// NEG_HEADROOM_TTFT_WEIGHT environment variable. A value that is unset,
// unparsable, or not >= 0 yields the default of 0.8 (TTFT dominates when
// violating SLOs).
var NegHeadroomTTFTWeight = func() float64 {
	const fallback = 0.8

	raw, set := os.LookupEnv("NEG_HEADROOM_TTFT_WEIGHT")
	if !set {
		return fallback
	}
	weight, err := strconv.ParseFloat(raw, 64)
	if err == nil && weight >= 0 {
		return weight
	}
	return fallback
}()

// NegHeadroomTPOTWeight is resolved once at package initialization from the
// NEG_HEADROOM_TPOT_WEIGHT environment variable. A value that is unset,
// unparsable, or not >= 0 yields the default of 0.2.
var NegHeadroomTPOTWeight = func() float64 {
	// An unset variable reads as "", which ParseFloat rejects, so unset and
	// malformed values both end up on the default.
	weight, err := strconv.ParseFloat(os.Getenv("NEG_HEADROOM_TPOT_WEIGHT"), 64)
	if err != nil {
		return 0.2
	}
	if weight >= 0 {
		return weight
	}
	return 0.2
}()

// HeadroomTTFTWeight is resolved once at package initialization from the
// HEADROOM_TTFT_WEIGHT environment variable. A value that is unset,
// unparsable, or not >= 0 yields the default of 0.8.
var HeadroomTTFTWeight = func() float64 {
	weight := 0.8
	if raw, set := os.LookupEnv("HEADROOM_TTFT_WEIGHT"); set {
		parsed, err := strconv.ParseFloat(raw, 64)
		if err == nil && parsed >= 0 {
			weight = parsed
		}
	}
	return weight
}()

// HeadroomTPOTWeight is resolved once at package initialization from the
// HEADROOM_TPOT_WEIGHT environment variable. A value that is unset,
// unparsable, or not >= 0 yields the default of 0.2.
var HeadroomTPOTWeight = func() float64 {
	const fallback = 0.2

	raw, set := os.LookupEnv("HEADROOM_TPOT_WEIGHT")
	if !set {
		return fallback
	}
	parsed, err := strconv.ParseFloat(raw, 64)
	if accepted := err == nil && parsed >= 0; accepted {
		return parsed
	}
	return fallback
}()

// HeadroomSelectionStrategy is resolved once at package initialization from
// the HEADROOM_SELECTION_STRATEGY environment variable, matched
// case-insensitively against "least", "most", "composite-least",
// "composite-most" and "composite-only". An unset or unrecognized value
// yields HeadroomStrategyLeast (better packing).
var HeadroomSelectionStrategy = func() HeadroomStrategy {
	raw, set := os.LookupEnv("HEADROOM_SELECTION_STRATEGY")
	if !set {
		return HeadroomStrategyLeast
	}

	byName := map[string]HeadroomStrategy{
		"least":           HeadroomStrategyLeast,
		"most":            HeadroomStrategyMost,
		"composite-least": HeadroomStrategyCompositeLeast,
		"composite-most":  HeadroomStrategyCompositeMost,
		"composite-only":  HeadroomStrategyCompositeOnly,
	}
	if strategy, known := byName[strings.ToLower(raw)]; known {
		return strategy
	}
	return HeadroomStrategyLeast
}()

// CompositeKVWeight is one of the per-component weights applied when composite
// headroom is in use (not used by default). It is resolved once at package
// initialization from COMPOSITE_KV_WEIGHT; a value that is unset, unparsable,
// or not >= 0 yields the default of 1.
var CompositeKVWeight = func() float64 {
	const fallback = 1

	raw, set := os.LookupEnv("COMPOSITE_KV_WEIGHT")
	if !set {
		return fallback
	}
	weight, err := strconv.ParseFloat(raw, 64)
	if err == nil && weight >= 0 {
		return weight
	}
	return fallback
}()

// CompositeQueueWeight is resolved once at package initialization from the
// COMPOSITE_QUEUE_WEIGHT environment variable. A value that is unset,
// unparsable, or not >= 0 yields the default of 1.
var CompositeQueueWeight = func() float64 {
	weight := 1.0
	if raw, set := os.LookupEnv("COMPOSITE_QUEUE_WEIGHT"); set {
		parsed, err := strconv.ParseFloat(raw, 64)
		if err == nil && parsed >= 0 {
			weight = parsed
		}
	}
	return weight
}()

// CompositePrefixWeight is resolved once at package initialization from the
// COMPOSITE_PREFIX_WEIGHT environment variable. A value that is unset,
// unparsable, or not >= 0 yields the default of 1.
var CompositePrefixWeight = func() float64 {
	// An unset variable reads as "", which ParseFloat rejects, so unset and
	// malformed values both end up on the default.
	weight, err := strconv.ParseFloat(os.Getenv("COMPOSITE_PREFIX_WEIGHT"), 64)
	if err == nil && weight >= 0 {
		return weight
	}
	return 1
}()

// EpsilonExploreSticky is the exploration probability ε: with probability ε,
// explore (ignore affinity gate); otherwise exploit. It is resolved once at
// package initialization from STICKY_EPSILON and must parse to a float in
// [0, 1]; anything else yields the default of 0.01 (1% exploration).
var EpsilonExploreSticky = func() float64 {
	const fallback = 0.01

	raw, set := os.LookupEnv("STICKY_EPSILON")
	if !set {
		return fallback
	}
	eps, err := strconv.ParseFloat(raw, 64)
	inRange := err == nil && eps >= 0 && eps <= 1
	if !inRange {
		return fallback
	}
	return eps
}()

// EpsilonExploreNeg is an exploration probability resolved once at package
// initialization from NEG_HEADROOM_EPSILON. It must parse to a float in
// [0, 1]; anything else yields the default of 0.01 (1% exploration).
// NOTE(review): presumably applies to the negative-headroom path — confirm
// against callers.
var EpsilonExploreNeg = func() float64 {
	eps := 0.01
	if raw, set := os.LookupEnv("NEG_HEADROOM_EPSILON"); set {
		parsed, err := strconv.ParseFloat(raw, 64)
		if err == nil && parsed >= 0 && parsed <= 1 {
			eps = parsed
		}
	}
	return eps
}()

// AffinityGateTau is τ for the per-path affinity gate (aka "stickiness"
// threshold). It is resolved once at package initialization from
// AFFINITY_GATE_TAU and must parse to a float in [0, 1]; anything else yields
// the default of 0.80.
var AffinityGateTau = func() float64 {
	const fallback = 0.80

	raw, set := os.LookupEnv("AFFINITY_GATE_TAU")
	if !set {
		return fallback
	}
	tau, err := strconv.ParseFloat(raw, 64)
	if err == nil && tau >= 0 && tau <= 1 {
		return tau
	}
	return fallback
}()

// AffinityGateTauGlobal is the global τ for the overall candidate set
// (previously "overall stickiness"). It is resolved once at package
// initialization from AFFINITY_GATE_TAU_GLOBAL and must parse to a float in
// [0, 1]; anything else yields the default of 0.99.
var AffinityGateTauGlobal = func() float64 {
	tau := 0.99
	raw, set := os.LookupEnv("AFFINITY_GATE_TAU_GLOBAL")
	if set {
		if parsed, err := strconv.ParseFloat(raw, 64); err == nil && parsed >= 0 && parsed <= 1 {
			tau = parsed
		}
	}
	return tau
}()

// SelectionMode is read once at init from POD_SELECTION_MODE. The value is
// compared case-insensitively: "max" selects PodSelectionMax, while "linear",
// any other value, or an unset variable selects PodSelectionLinear (default).
var SelectionMode = func() PodSelectionMode {
	raw, set := os.LookupEnv("POD_SELECTION_MODE")
	if set && strings.ToLower(raw) == "max" {
		return PodSelectionMax
	}
	return PodSelectionLinear
}()
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
© 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

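Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.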
*/

// Package slo_aware_router contains helpers to decouple latency-predictor logic.
package slo_aware_router

import (
"fmt"
"strconv"

schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
)

// parseFloatHeader looks up headerName in request.Headers and parses it with
// strconv.ParseFloat (64-bit).
//
// Returns (value, true, nil) when the header is present and parses cleanly.
// A missing header is not an error: it yields (0, false, nil). A header that
// is present but not a valid float yields (0, false, err), where err is an
// errutil.Error with code errutil.BadRequest.
func parseFloatHeader(request schedulingtypes.LLMRequest, headerName string) (float64, bool, error) {
	raw, present := request.Headers[headerName]
	if !present {
		return 0, false, nil
	}

	if value, err := strconv.ParseFloat(raw, 64); err == nil {
		return value, true, nil
	}

	return 0, false, errutil.Error{
		Code: errutil.BadRequest,
		Msg:  fmt.Sprintf("%s must be a float", headerName),
	}
}

// parseBoolHeader looks up headerName in request.Headers and parses it with
// strconv.ParseBool.
//
// A missing header is not an error: it yields (false, nil). A header that is
// present but not a valid bool yields false and an errutil.Error with code
// errutil.BadRequest.
func parseBoolHeader(request schedulingtypes.LLMRequest, headerName string) (bool, error) {
	raw, present := request.Headers[headerName]
	if !present {
		return false, nil
	}

	if flag, err := strconv.ParseBool(raw); err == nil {
		return flag, nil
	}

	return false, errutil.Error{
		Code: errutil.BadRequest,
		Msg:  fmt.Sprintf("%s must be a bool", headerName),
	}
}
145 changes: 145 additions & 0 deletions pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package slo_aware_router

import (
"context"
"math"
"math/rand"

"sigs.k8s.io/controller-runtime/pkg/log"
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

// selectFromCompositeScores picks a pod from allPreds using the composite
// weights produced by buildCompositeChoices with the package-level
// CompositeKVWeight, CompositeQueueWeight and CompositePrefixWeight settings.
//
// For HeadroomStrategyCompositeLeast every weight is mirrored to
// minWeight + Wmax - weight, so pods with the lowest composite score receive
// the largest weight. The running total is recomputed from the mirrored
// weights: reusing the pre-inversion sum would give
// performWeightedRandomSelection a draw range that no longer matches the
// choices, skewing the distribution or overshooting every choice.
//
// The selection itself is delegated to performWeightedRandomSelection and its
// result is returned unchanged.
func (s *SLOAwareRouter) selectFromCompositeScores(ctx context.Context, allPreds []PodPredictionResult, r *rand.Rand, strategy HeadroomStrategy) schedulingtypes.Pod {
	total := 0
	choices := s.buildCompositeChoices(
		ctx, allPreds, CompositeKVWeight, CompositeQueueWeight, CompositePrefixWeight, &total,
	)
	if strategy == HeadroomStrategyCompositeLeast {
		// Invert weights for "least" strategy and keep total in sync with them.
		total = 0
		for i := range choices {
			choices[i].Weight = minWeight + Wmax - choices[i].Weight
			total += choices[i].Weight
		}
	}
	return s.performWeightedRandomSelection(choices, total, allPreds, r)
}
// performWeightedRandomSelection chooses one pod from weightedChoices.
//
// total is expected to be the sum of the choice weights. A non-positive total
// returns nil: there is nothing to draw from, and rand.Intn panics for
// n <= 0.
//
// When SelectionMode is PodSelectionMax the choice with the strictly largest
// positive weight wins (the first one on ties). Otherwise an index is drawn
// uniformly from [0, total) using r and the choices are walked, subtracting
// each weight, until the index falls inside one of them.
//
// If neither mode produces a pod, the first entry of candidates is returned,
// or nil when candidates is empty.
func (s *SLOAwareRouter) performWeightedRandomSelection(weightedChoices []Choice, total int, candidates []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod {
	if total <= 0 {
		return nil
	}

	// Fallback used when no choice is selected; stays nil for empty candidates
	// instead of indexing out of range.
	var fallback schedulingtypes.Pod
	if len(candidates) > 0 {
		fallback = candidates[0].Pod
	}

	logger := log.FromContext(context.Background())
	// SelectionMode is resolved from the POD_SELECTION_MODE env variable at init.
	if SelectionMode == PodSelectionMax {
		logger.V(logutil.DEBUG).Info("Pod selection mode: MAX - selecting pod with highest weight")
		maxWeight := 0
		var selectedPod schedulingtypes.Pod
		for _, c := range weightedChoices {
			if c.Weight > maxWeight {
				maxWeight = c.Weight
				selectedPod = c.PodName
			}
		}
		if selectedPod != nil {
			return selectedPod
		}
		// No choice had a positive weight.
		return fallback
	}

	// Original weighted random selection logic
	logger.V(logutil.DEBUG).Info("Pod selection mode: LINEAR - performing weighted random selection")
	idx := r.Intn(total)
	for _, c := range weightedChoices {
		if idx < c.Weight {
			if c.PodName != nil {
				return c.PodName
			}
			// A nil pod in the winning slot falls through to the fallback,
			// matching the behavior of the nil check after the loop.
			break
		}
		idx -= c.Weight
	}

	// Reached only when total exceeds the sum of the weights or the winning
	// choice carried no pod.
	return fallback
}
// buildCompositeChoices converts each candidate into a weighted Choice.
//
// wkv, wq and wpref are normalized to sum to 1; when their sum is not
// positive the KV component alone is used. For every candidate:
//
//	kvFree    = 1 - KVCacheUsagePercent
//	relQueue  = (maxQ - q) / (maxQ - minQ), or 1 when all queues are equal
//	composite = wkv*kvFree + wq*relQueue + wpref*PrefixCacheScore
//	weight    = round(minWeight + (Wmax-minWeight)*composite)
//
// where q is the candidate's WaitingQueueSize and minQ/maxQ are taken over all
// candidates. Each weight is added to *total, which must be non-nil. The
// returned slice has one entry per candidate, in candidate order.
//
// NOTE(review): assumes KVCacheUsagePercent and PrefixCacheScore lie in
// [0, 1]; values outside that range push weights outside [minWeight, Wmax] —
// confirm against the metric producers.
func (s *SLOAwareRouter) buildCompositeChoices(
	ctx context.Context,
	candidates []PodPredictionResult,
	wkv, wq, wpref float64,
	total *int,
) []Choice {

	// Normalize weights
	sumw := wkv + wq + wpref
	if sumw <= 0 {
		wkv, wq, wpref = 1, 0, 0
	} else {
		wkv /= sumw
		wq /= sumw
		wpref /= sumw
	}

	// Precompute queue stats. Sizes are stored by candidate index so the second
	// pass sees exactly the value used for min/max, without building string
	// keys or risking two pods sharing one key.
	minQ, maxQ := math.MaxInt32, -1
	queueSizes := make([]int, len(candidates))
	for i, p := range candidates {
		q := p.Pod.GetMetrics().WaitingQueueSize
		queueSizes[i] = q
		if q < minQ {
			minQ = q
		}
		if q > maxQ {
			maxQ = q
		}
	}
	den := float64(maxQ - minQ)

	logger := log.FromContext(ctx).V(logutil.DEBUG)
	choices := make([]Choice, 0, len(candidates))
	for i, p := range candidates {
		q := queueSizes[i]
		// With identical queue sizes there is no spread to rank on, so every
		// pod gets the full queue score.
		relQueue := 1.0
		if den > 0 {
			relQueue = float64(maxQ-q) / den
		}

		kvUsage := p.Pod.GetMetrics().KVCacheUsagePercent
		kvFree := 1.0 - kvUsage
		prefix := p.PrefixCacheScore

		composite := wkv*kvFree + wq*relQueue + wpref*prefix
		w := int(math.Round(float64(minWeight) + (float64(Wmax-minWeight) * composite)))
		*total += w
		choices = append(choices, Choice{PodName: p.Pod, Weight: w})

		logger.Info("Composite (neg/pos) score",
			"pod", p.Pod.GetPod().String(),
			"kvUsage", kvUsage, "kvFree", kvFree,
			"queue", q, "relQueue", relQueue,
			"prefix", prefix,
			"wkv", wkv, "wq", wq, "wprefix", wpref,
			"composite", composite, "weight", w)
	}
	return choices
}
Loading
Loading