diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go
new file mode 100644
index 000000000..cf9d8ee33
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go
@@ -0,0 +1,169 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package slo_aware_router contains helpers to decouple latency-predictor logic.
+package slo_aware_router
+
+import (
+	"os"
+	"strconv"
+	"strings"
+)
+
+var SLOBufferFactor = func() float64 {
+	if value, exists := os.LookupEnv("SLO_BUFFER_FACTOR"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil {
+			return parsedValue
+		}
+	}
+	return 1.0 // default value
+}()
+
+var NegHeadroomTTFTWeight = func() float64 {
+	if value, exists := os.LookupEnv("NEG_HEADROOM_TTFT_WEIGHT"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 {
+			return parsedValue
+		}
+	}
+	return 0.8 // default: TTFT dominates when violating SLOs
+}()
+
+var NegHeadroomTPOTWeight = func() float64 {
+	if value, exists := os.LookupEnv("NEG_HEADROOM_TPOT_WEIGHT"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 {
+			return parsedValue
+		}
+	}
+	return 0.2 // default: TPOT matters less for short-output workloads
+}()
+
+var HeadroomTTFTWeight = func() float64 {
+	if value, exists := os.LookupEnv("HEADROOM_TTFT_WEIGHT"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 {
+			return parsedValue
+		}
+	}
+	return 0.8 // default
+}()
+
+var HeadroomTPOTWeight = func() float64 {
+	if value, exists := os.LookupEnv("HEADROOM_TPOT_WEIGHT"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 {
+			return parsedValue
+		}
+	}
+	return 0.2 // default
+}()
+
+var HeadroomSelectionStrategy = func() HeadroomStrategy {
+	if value, exists := os.LookupEnv("HEADROOM_SELECTION_STRATEGY"); exists {
+		switch strings.ToLower(value) {
+		case "least":
+			return HeadroomStrategyLeast
+		case "most":
+			return HeadroomStrategyMost
+		case "composite-least":
+			return HeadroomStrategyCompositeLeast
+		case "composite-most":
+			return HeadroomStrategyCompositeMost
+		case "composite-only":
+			return HeadroomStrategyCompositeOnly
+		}
+	}
+	return HeadroomStrategyLeast // default to least (better packing)
+}()
+
+// Weights for each component when a composite headroom strategy is used. Not used by default.
+var CompositeKVWeight = func() float64 {
+	if v, ok := os.LookupEnv("COMPOSITE_KV_WEIGHT"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 {
+			return f
+		}
+	}
+	return 1
+}()
+
+var CompositeQueueWeight = func() float64 {
+	if v, ok := os.LookupEnv("COMPOSITE_QUEUE_WEIGHT"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 {
+			return f
+		}
+	}
+	return 1
+}()
+
+var CompositePrefixWeight = func() float64 {
+	if v, ok := os.LookupEnv("COMPOSITE_PREFIX_WEIGHT"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 {
+			return f
+		}
+	}
+	return 1
+}()
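Review note: the `LookupEnv` + `ParseFloat` pattern above is repeated verbatim for every tunable. A hypothetical helper could collapse each block to one line; the `envFloat` name and the bounds arguments below are illustrative, not part of this change:

```go
// envFloat returns the named environment variable as a float64, falling
// back to def when the variable is unset, unparsable, or outside [lo, hi].
func envFloat(name string, def, lo, hi float64) float64 {
	if v, ok := os.LookupEnv(name); ok {
		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= lo && f <= hi {
			return f
		}
	}
	return def
}

// e.g. var HeadroomTTFTWeight = envFloat("HEADROOM_TTFT_WEIGHT", 0.8, 0, math.MaxFloat64)
// e.g. var AffinityGateTau = envFloat("AFFINITY_GATE_TAU", 0.80, 0, 1)
```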
+// With probability ε, explore (ignore affinity gate); otherwise exploit.
+var EpsilonExploreSticky = func() float64 {
+	if v, ok := os.LookupEnv("STICKY_EPSILON"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 {
+			return f
+		}
+	}
+	return 0.01 // default 1% exploration
+}()
+
+var EpsilonExploreNeg = func() float64 {
+	if v, ok := os.LookupEnv("NEG_HEADROOM_EPSILON"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 {
+			return f
+		}
+	}
+	return 0.01 // default 1% exploration
+}()
+
+// τ for per-path affinity gate (aka "stickiness" threshold).
+var AffinityGateTau = func() float64 {
+	if v, ok := os.LookupEnv("AFFINITY_GATE_TAU"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 {
+			return f
+		}
+	}
+	return 0.80
+}()
+
+// Global τ for the overall candidate set (previously "overall stickiness").
+var AffinityGateTauGlobal = func() float64 {
+	if v, ok := os.LookupEnv("AFFINITY_GATE_TAU_GLOBAL"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 {
+			return f
+		}
+	}
+	return 0.99
+}()
+
+// Read once at init. Values: "linear" (default) or "max".
+var SelectionMode = func() PodSelectionMode {
+	if v, ok := os.LookupEnv("POD_SELECTION_MODE"); ok {
+		switch strings.ToLower(v) {
+		case "max":
+			return PodSelectionMax
+		case "linear":
+			fallthrough
+		default:
+			return PodSelectionLinear
+		}
+	}
+	return PodSelectionLinear
+}()
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go
new file mode 100644
index 000000000..2588b3104
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go
@@ -0,0 +1,66 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package slo_aware_router contains helpers to decouple latency-predictor logic.
+package slo_aware_router
+
+import (
+	"fmt"
+	"strconv"
+
+	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+	errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
+)
+
+// parseFloatHeader retrieves a header by name and parses it as a float64.
+// It returns the value, a flag reporting whether the header was present,
+// and an error when the value cannot be parsed.
+func parseFloatHeader(request schedulingtypes.LLMRequest, headerName string) (float64, bool, error) {
+	// 1. Get header value from the map
+	headerValue, ok := request.Headers[headerName]
+	if !ok {
+		return 0, false, nil // Header not found, return 0 and false
+	}
+
+	// 2. Parse the header value to a float64
+	parsedFloat, err := strconv.ParseFloat(headerValue, 64)
+	if err != nil {
+		return 0, false, errutil.Error{
+			Code: errutil.BadRequest,
+			Msg:  fmt.Sprintf("%s must be a float", headerName),
+		}
+	}
+
+	// 3. Return the successfully parsed value
+	return parsedFloat, true, nil
+}
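Review note: a quick illustration of the parser's behavior for the SLO headers, assuming `LLMRequest` can be constructed here with just its `Headers` field (values are made up):

```go
func ExampleParseFloatHeader() {
	req := schedulingtypes.LLMRequest{
		Headers: map[string]string{TTFTSLOHeaderKey: "200"},
	}
	ttft, ok, err := parseFloatHeader(req, TTFTSLOHeaderKey)
	fmt.Println(ttft, ok, err) // 200 true <nil>

	_, ok, err = parseFloatHeader(req, TPOTSLOHeaderKey)
	fmt.Println(ok, err) // false <nil> — an absent header is not an error
}
```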
+// parseBoolHeader retrieves a header by name and parses it as a bool.
+// A missing header yields false with no error; an unparsable value
+// yields an error.
+func parseBoolHeader(request schedulingtypes.LLMRequest, headerName string) (bool, error) {
+	// 1. Get header value from the map
+	headerValue, ok := request.Headers[headerName]
+	if !ok {
+		return false, nil // Header not found; default to false
+	}
+
+	// 2. Parse the header value to a bool
+	parsedBool, err := strconv.ParseBool(headerValue)
+	if err != nil {
+		return false, errutil.Error{
+			Code: errutil.BadRequest,
+			Msg:  fmt.Sprintf("%s must be a bool", headerName),
+		}
+	}
+
+	// 3. Return the successfully parsed value
+	return parsedBool, nil
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go
new file mode 100644
index 000000000..3f02d5e52
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go
@@ -0,0 +1,145 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package slo_aware_router
+
+import (
+	"context"
+	"math"
+	"math/rand"
+
+	"sigs.k8s.io/controller-runtime/pkg/log"
+	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+)
+
+func (s *SLOAwareRouter) selectFromCompositeScores(ctx context.Context, allPreds []PodPredictionResult, r *rand.Rand, strategy HeadroomStrategy) schedulingtypes.Pod {
+	total := 0
+	choices := s.buildCompositeChoices(
+		ctx, allPreds, CompositeKVWeight, CompositeQueueWeight, CompositePrefixWeight, &total,
+	)
+	if strategy == HeadroomStrategyCompositeLeast {
+		// Invert weights for the "least" strategy
+		for i := range choices {
+			choices[i].Weight = minWeight + Wmax - choices[i].Weight
+		}
+	}
+	return s.performWeightedRandomSelection(choices, total, allPreds, r)
+}
+
+func (s *SLOAwareRouter) performWeightedRandomSelection(weightedChoices []Choice, total int, candidates []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod {
+	if total == 0 {
+		return nil
+	}
+	logger := log.FromContext(context.Background())
+	// SelectionMode is read once at init from POD_SELECTION_MODE.
+	if SelectionMode == PodSelectionMax {
+		logger.V(logutil.DEBUG).Info("Pod selection mode: MAX - selecting pod with highest weight")
+		maxWeight := 0
+		var selectedPod schedulingtypes.Pod
+		for _, c := range weightedChoices {
+			if c.Weight > maxWeight {
+				maxWeight = c.Weight
+				selectedPod = c.PodName
+			}
+		}
+		if selectedPod != nil {
+			return selectedPod
+		}
+		// Fallback to first pod if no selection made
+		return candidates[0].Pod
+	}
+
+	// Weighted random selection
+	logger.V(logutil.DEBUG).Info("Pod selection mode: LINEAR - performing weighted random selection")
+	idx := r.Intn(total)
+	var selectedPod schedulingtypes.Pod
+
+	for _, c := range weightedChoices {
+		if idx < c.Weight {
+			selectedPod = c.PodName
+			break
+		}
+		idx -= c.Weight
+	}
+
+	// If no pod was selected (shouldn't happen), fall back to the first pod
+	if selectedPod == nil {
+		selectedPod = candidates[0].Pod
+	}
+
+	return selectedPod
+}
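Review note: the LINEAR branch above is standard roulette-wheel sampling: draw `idx` in `[0, total)` and walk the cumulative weights. A tiny self-contained sketch with toy weights (not values from this router):

```go
func exampleWeightedPick(r *rand.Rand) string {
	type choice struct {
		name   string
		weight int
	}
	choices := []choice{{"pod-a", 70}, {"pod-b", 20}, {"pod-c", 10}}
	idx := r.Intn(100) // 100 is the total weight
	for _, c := range choices {
		if idx < c.weight {
			return c.name // pod-a ~70%, pod-b ~20%, pod-c ~10% of draws
		}
		idx -= c.weight
	}
	return "" // unreachable while weights sum to the Intn bound
}
```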
+
+func (s *SLOAwareRouter) buildCompositeChoices(
+	ctx context.Context,
+	candidates []PodPredictionResult,
+	wkv, wq, wpref float64,
+	total *int,
+) []Choice {
+	// Normalize weights
+	sumw := wkv + wq + wpref
+	if sumw <= 0 {
+		wkv, wq, wpref = 1, 0, 0
+	} else {
+		wkv /= sumw
+		wq /= sumw
+		wpref /= sumw
+	}
+
+	// Precompute queue stats
+	minQ, maxQ := math.MaxInt32, -1
+	queueCounts := make(map[string]int, len(candidates))
+	for _, p := range candidates {
+		q := p.Pod.GetMetrics().WaitingQueueSize
+		queueCounts[p.Pod.GetPod().String()] = q
+		if q < minQ {
+			minQ = q
+		}
+		if q > maxQ {
+			maxQ = q
+		}
+	}
+	den := float64(maxQ - minQ)
+
+	choices := make([]Choice, 0, len(candidates))
+	for _, p := range candidates {
+		q := queueCounts[p.Pod.GetPod().String()]
+		relQueue := 1.0
+		if den > 0 {
+			relQueue = float64(maxQ-q) / den
+		}
+
+		kvUsage := p.Pod.GetMetrics().KVCacheUsagePercent
+		kvFree := 1.0 - kvUsage
+		prefix := p.PrefixCacheScore
+
+		composite := wkv*kvFree + wq*relQueue + wpref*prefix
+		w := int(math.Round(float64(minWeight) + (float64(Wmax-minWeight) * composite)))
+		*total += w
+		choices = append(choices, Choice{PodName: p.Pod, Weight: w})
+
+		log.FromContext(ctx).V(logutil.DEBUG).Info("Composite (neg/pos) score",
+			"pod", p.Pod.GetPod().String(),
+			"kvUsage", kvUsage, "kvFree", kvFree,
+			"queue", q, "relQueue", relQueue,
+			"prefix", prefix,
+			"wkv", wkv, "wq", wq, "wprefix", wpref,
+			"composite", composite, "weight", w)
+	}
+	return choices
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/plugin.go
deleted file mode 100644
index fe5a65e79..000000000
--- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/plugin.go
+++ /dev/null
@@ -1,937 +0,0 @@
-/*
-Copyright 2025 The Kubernetes Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/ - -package slo_aware_router - -import ( - "context" - "fmt" - "math" - "math/rand" - "os" - "strconv" - "strings" - "time" - - "sigs.k8s.io/controller-runtime/pkg/log" - - "k8s.io/apimachinery/pkg/types" - latencypredictor "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/latencypredictorasync" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" - schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" - errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" -) - -// HeadroomStrategy defines how positive headroom pods should be weighted -type HeadroomStrategy string - -type Choice struct { - PodName schedulingtypes.Pod - Weight int -} - -const ( - // HeadroomStrategyLeast prioritizes pods with least positive headroom (better packing) - HeadroomStrategyLeast HeadroomStrategy = "least" - // HeadroomStrategyMost prioritizes pods with most positive headroom (more conservative) - HeadroomStrategyMost HeadroomStrategy = "most" - - // TTFT header string - TTFTSLOHeaderKey = "x-slo-ttft-ms" - // TPOT header string - TPOTSLOHeaderKey = "x-slo-tpot-ms" -) - -const ( - SLOAwareRouterPluginType = "slo-aware-routing" - MinScore = 0 - MaxScore = 100 -) - -var SLOBufferFactor = func() float64 { - if value, exists := os.LookupEnv("SLO_BUFFER_FACTOR"); exists { - if parsedValue, err := strconv.ParseFloat(value, 64); err == nil { - return parsedValue - } - } - return 1.0 // default value -}() - -var NegHeadroomTTFTWeight = func() float64 { - if value, exists := os.LookupEnv("NEG_HEADROOM_TTFT_WEIGHT"); exists { - if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 { - return parsedValue - } - } - return 0.8 // default: TTFT dominates when violating SLOs -}() - -var NegHeadroomTPOTWeight = func() float64 { - if value, exists := os.LookupEnv("NEG_HEADROOM_TPOT_WEIGHT"); exists { - if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 { - return parsedValue - } - } - return 0.2 // default: TPOT less important in your tiny-output scenario -}() - -var HeadroomTTFTWeight = func() float64 { - if value, exists := os.LookupEnv("HEADROOM_TTFT_WEIGHT"); exists { - if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 { - return parsedValue - } - } - return 0.8 // default -}() - -var HeadroomTPOTWeight = func() float64 { - if value, exists := os.LookupEnv("HEADROOM_TPOT_WEIGHT"); exists { - if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 { - return parsedValue - } - } - return 0.2 // default -}() - -var HeadroomSelectionStrategy = func() HeadroomStrategy { - if value, exists := os.LookupEnv("HEADROOM_SELECTION_STRATEGY"); exists { - switch strings.ToLower(value) { - case "least": - return HeadroomStrategyLeast - case "most": - return HeadroomStrategyMost - } - } - return HeadroomStrategyLeast // default to least (better packing) -}() - -// With probability ε, explore (ignore affinity gate); otherwise exploit. -var EpsilonExplore = func() float64 { - // Prefer new env; fall back to old for compatibility. 
- if v, ok := os.LookupEnv("STICKY_EPSILON"); ok { - if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 { - return f - } - } - return 0.01 // default 1% exploration -}() - -// τ for per-path affinity gate (aka "stickiness" threshold). -var AffinityGateTau = func() float64 { - // Prefer new env; fall back to old for compatibility. - if v, ok := os.LookupEnv("AFFINITY_GATE_TAU"); ok { - if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 { - return f - } - } - return 0.80 -}() - -// Global τ for the overall candidate set (previously "overall stickiness"). -var AffinityGateTauGlobal = func() float64 { - // Prefer new env; fall back to old for compatibility. - if v, ok := os.LookupEnv("AFFINITY_GATE_TAU_GLOBAL"); ok { - if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 { - return f - } - } - return 0.99 -}() - -// parseFloatHeader retrieves a header by name, parses it as a float64, -// and returns the value or an error if the header is missing or invalid. -func parseFloatHeader(request schedulingtypes.LLMRequest, headerName string) (float64, bool, error) { - // 1. Get header value from the map - headerValue, ok := request.Headers[headerName] - if !ok { - return 0, false, nil // Header not found, return 0 and false - } - - // 2. Parse the header value to a float64 - parsedFloat, err := strconv.ParseFloat(headerValue, 64) - if err != nil { - return 0, false, errutil.Error{ - Code: errutil.BadRequest, - Msg: fmt.Sprintf("%s must be a float", headerName), - } - } - - // 3. Return the successfully parsed value - return parsedFloat, true, nil -} - -// parseFloatHeader retrieves a header by name, parses it as a bool, -// and returns the value or an error if the header is missing or invalid. -func parseBoolHeader(request schedulingtypes.LLMRequest, headerName string) (bool, error) { - // 1. Get header value from the map - headerValue, ok := request.Headers[headerName] - if !ok { - return false, nil // Header not found, return 0 and false - } - - // 2. Parse the header value to a bool - parsedBool, err := strconv.ParseBool(headerValue) - if err != nil { - return false, errutil.Error{ - Code: errutil.BadRequest, - Msg: fmt.Sprintf("%s must be a bool", headerName), - } - } - - // 3. 
Return the successfully parsed value - return parsedBool, nil -} - -type PodPredictionResult struct { - Pod schedulingtypes.Pod - TTFT float64 - TPOT float64 - TTFTValid bool - TPOTValid bool - IsValid bool - Error error - Headroom float64 // Headroom for the pod, if applicable - TTFTHeadroom float64 // TTFT headroom for the pod - PrefixCacheScore float64 // Prefix cache score for the pod -} - -type SLOAwareRouter struct { - tn plugins.TypedName - latencypredictor latencypredictor.PredictorInterface - runningRequestLists map[types.NamespacedName]*RequestPriorityQueue - sloContextStore map[string]*SLORequestContext - headroomStrategy HeadroomStrategy -} - -func (s *SLOAwareRouter) Dependencies() []plugins.TypedName { - return []plugins.TypedName{ - {Type: "prefix-cache-scorer", Name: "prefix-cache-scorer"}, - } -} - -var _ framework.Scorer = &SLOAwareRouter{} - -func NewSLOAwareRouter(latencypredictor latencypredictor.PredictorInterface, strategy HeadroomStrategy) *SLOAwareRouter { - return &SLOAwareRouter{ - tn: plugins.TypedName{Type: SLOAwareRouterPluginType, Name: SLOAwareRouterPluginType}, - latencypredictor: latencypredictor, - runningRequestLists: make(map[types.NamespacedName]*RequestPriorityQueue), - sloContextStore: make(map[string]*SLORequestContext), - headroomStrategy: strategy, - } -} - -func (s *SLOAwareRouter) TypedName() plugins.TypedName { - return s.tn -} - -func (s *SLOAwareRouter) WithName(name string) *SLOAwareRouter { - s.tn.Name = name - return s -} - -// SetHeadroomStrategy allows runtime configuration of headroom selection strategy -func (s *SLOAwareRouter) SetHeadroomStrategy(strategy HeadroomStrategy) { - s.headroomStrategy = strategy -} - -// GetHeadroomStrategy returns the current headroom selection strategy -func (s *SLOAwareRouter) GetHeadroomStrategy() HeadroomStrategy { - return s.headroomStrategy -} - -func (s *SLOAwareRouter) epsilonGreedyAffinityGate( - ctx context.Context, - candidates []PodPredictionResult, - r *rand.Rand, - label string, // e.g. "positive" or "negative" - prefixStickyThreshold float64, -) ([]PodPredictionResult, bool) { - logger := log.FromContext(ctx) - - eligible := make([]PodPredictionResult, 0, len(candidates)) - for _, p := range candidates { - if p.PrefixCacheScore >= prefixStickyThreshold { - eligible = append(eligible, p) - } - } - - // No eligible sticky pods? Explore (no gating). 
- if len(eligible) == 0 { - return candidates, false - } - - // ε-exploration branch - if r.Float64() < EpsilonExplore { - logger.V(logutil.DEBUG).Info("ε-greedy: exploring (ignoring affinity gate)", - "path", label, "epsilon", EpsilonExplore, "eligibleCount", len(eligible)) - return candidates, false - } - - logger.V(logutil.DEBUG).Info("ε-greedy: exploiting (apply affinity gate)", - "path", label, "threshold", prefixStickyThreshold, "eligibleCount", len(eligible), "total", len(candidates)) - return eligible, true -} - -func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) map[schedulingtypes.Pod]float64 { - logger := log.FromContext(ctx) - if s.latencypredictor == nil { - logger.V(logutil.DEBUG).Info("SLOAwareRouter: no predictor configured, returning nil scores") - return nil - } - - sloCtx := s.getOrMakeSLORequestContext(request) - - var err error - // get request slos - // Get Request SLOs from request header - sloCtx.TTFTSLO, _, err = parseFloatHeader(*request, TTFTSLOHeaderKey) - if err != nil { - logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", TTFTSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TTFT SLO from header") - } - - sloCtx.AvgTPOTSLO, _, err = parseFloatHeader(*request, TPOTSLOHeaderKey) - if err != nil { - logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", TPOTSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TPOT SLO from header") - } - sloCtx.PredictorBasedScheduling, err = parseBoolHeader(*request, "x-prediction-based-scheduling") - if err != nil { - logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("x-prediction-based-scheduling must be a bool: %v", err)}, "SLOAwareRouter: Error parsing PredictorBasedScheduling from header") - } - - // Check if SLOs are provided - if !sloCtx.PredictorBasedScheduling { - logger.V(logutil.DEBUG).Info("PredictorBasedScheduling turned off, skipping prediction-based filtering") - return nil - } - - predictions := s.generatePredictions(ctx, state, request, sloCtx, pods) - s.updateRequestContextWithPredictions(sloCtx, predictions) - - allPreds := append([]PodPredictionResult(nil), predictions...) 
- - // Initialize scores map with all pods having score 0 - scores := make(map[schedulingtypes.Pod]float64, len(pods)) - for _, pod := range pods { - scores[pod] = 0 - } - - source := rand.NewSource(time.Now().UnixNano()) - r := rand.New(source) - allPreds, sticky := s.epsilonGreedyAffinityGate(ctx, allPreds, r, "overall", AffinityGateTauGlobal) - - // Check if all pods are invalid and all have running requests - allPodsInvalid := true - allPodsHaveRunningRequests := true - - for _, pred := range allPreds { - if pred.IsValid { - allPodsInvalid = false - } - - runningRequestCount := s.getPodRunningRequestCount(pred.Pod) - if runningRequestCount == 0 { - allPodsHaveRunningRequests = false - } - } - - // Set HasValidPod to false if all pods are invalid and all have running requests - if allPodsInvalid && allPodsHaveRunningRequests && !sticky { - sloCtx.HasValidPod = false - logger.V(logutil.DEBUG).Info("All pods are invalid and have running requests, setting HasValidPod to false") - } - - // 2) Tiered selection: positive headroom pods get 99% probability, negative get 1% - var posHeadroomPods, negHeadroomPods []PodPredictionResult - for _, p := range allPreds { - // A pod has positive headroom only if BOTH TTFT and TPOT have positive headroom - if p.Headroom > 0 && p.TTFTHeadroom > 0 { - posHeadroomPods = append(posHeadroomPods, p) - } else { - // A pod has negative headroom if EITHER TTFT or TPOT has negative/zero headroom - negHeadroomPods = append(negHeadroomPods, p) - } - } - - logger.V(logutil.DEBUG).Info("Pod headroom distribution", - "positivePods", len(posHeadroomPods), - "negativePods", len(negHeadroomPods)) - - var selectedPod schedulingtypes.Pod - - // If both positive and negative headroom pods exist, use tiered selection - if len(posHeadroomPods) > 0 && len(negHeadroomPods) > 0 { - // 99% chance to select from positive headroom pods, 1% from negative - if r.Float64() < EpsilonExplore { - logger.V(logutil.DEBUG).Info("Selecting from negative headroom pods (1% chance)") - selectedPod = s.selectFromNegativeHeadroomPods(ctx, negHeadroomPods, r) - } else { - logger.V(logutil.DEBUG).Info("Selecting from positive headroom pods (99% chance)") - selectedPod = s.selectFromPositiveHeadroomPods(ctx, posHeadroomPods, r) - } - } else if len(posHeadroomPods) > 0 { - // If only positive headroom pods exist, select from them - logger.V(logutil.DEBUG).Info("Only positive headroom pods available") - selectedPod = s.selectFromPositiveHeadroomPods(ctx, posHeadroomPods, r) - } else if len(negHeadroomPods) > 0 { - // If only negative headroom pods exist, select from them - logger.V(logutil.DEBUG).Info("Only negative headroom pods available") - selectedPod = s.selectFromNegativeHeadroomPods(ctx, negHeadroomPods, r) - } else if len(allPreds) > 0 { - // fallback - select randomly from valid pods - logger.V(logutil.DEBUG).Info("No headroom pods available, selecting randomly from valid pods") - selectedPod = allPreds[r.Intn(len(allPreds))].Pod - } else { - // No valid pods - return all zeros - logger.V(logutil.DEBUG).Info("No valid pods available, returning all zero scores") - return scores - } - - // Set score = 1 for selected pod, 0 for all others - if selectedPod != nil { - scores[selectedPod] = 1 - logger.V(logutil.DEBUG).Info("Selected pod for scheduling", "pod", selectedPod.GetPod().String()) - } - - s.setSLOContextForRequest(request, sloCtx) - - return scores -} - -func (t *SLOAwareRouter) getOrMakeSLORequestContext(request *schedulingtypes.LLMRequest) *SLORequestContext { - sloCtx, err := 
t.getSLOContextForRequest(request) - if err != nil { - sloCtx = NewSLORequestContext(request) - } - return sloCtx -} - -// selectFromPositiveHeadroomPods selects a pod from positive headroom pods using headroom strategy -// Updated to incorporate TTFTHeadroom with a configurable blend vs TPOT headroom. -func (s *SLOAwareRouter) selectFromPositiveHeadroomPods(ctx context.Context, posHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod { - logger := log.FromContext(ctx) - - if len(posHeadroomPods) == 1 { - return posHeadroomPods[0].Pod - } - - // Apply perfect stickiness (with exploration) - candidates, sticky := s.epsilonGreedyAffinityGate(ctx, posHeadroomPods, r, "positive", AffinityGateTau) - - // If perfect stickiness collapsed us to a single pod, short-circuit - if sticky && len(candidates) == 1 { - return candidates[0].Pod - } - const Wmax = 100 - const minWeight = 1 - const eps = 1e-9 - - // Find min/max for TPOT (Headroom) and TTFTHeadroom across positive pods to normalize to [0,1] - minTPOTH, maxTPOTH := math.MaxFloat64, -math.MaxFloat64 - minTTFTH, maxTTFTH := math.MaxFloat64, -math.MaxFloat64 - - for _, p := range candidates { - if p.Headroom < minTPOTH { - minTPOTH = p.Headroom - } - if p.Headroom > maxTPOTH { - maxTPOTH = p.Headroom - } - if p.TTFTHeadroom < minTTFTH { - minTTFTH = p.TTFTHeadroom - } - if p.TTFTHeadroom > maxTTFTH { - maxTTFTH = p.TTFTHeadroom - } - } - - tpotRange := maxTPOTH - minTPOTH - ttftRange := maxTTFTH - minTTFTH - - // Precompute blend weights (renormalize if user sets both to 0) - alpha := HeadroomTTFTWeight - beta := HeadroomTPOTWeight - if alpha+beta <= 0 { - alpha = 1.0 - beta = 0.0 - } - sum := alpha + beta - alpha /= sum - beta /= sum - - logger.V(logutil.DEBUG).Info("Positive headroom normalization ranges", - "minTPOTHeadroom", minTPOTH, "maxTPOTHeadroom", maxTPOTH, - "minTTFTHeadroom", minTTFTH, "maxTTFTHeadroom", maxTTFTH, - "alphaTTFT", alpha, "betaTPOT", beta, "strategy", s.headroomStrategy) - - // Calculate weights for weighted random selection - weightedChoices := make([]Choice, 0, len(candidates)) - total := 0 - - for _, p := range candidates { - // Normalize to [0,1] within the cohort - nTPOTH := 0.5 - if tpotRange > eps { - nTPOTH = (p.Headroom - minTPOTH) / (tpotRange + eps) - } - nTTFTH := 0.5 - if ttftRange > eps { - nTTFTH = (p.TTFTHeadroom - minTTFTH) / (ttftRange + eps) - } - - // Blend: larger combined -> "safer"; smaller -> "tighter packing" - combined := alpha*nTTFTH + beta*nTPOTH - - // Map to integer weights - var w int - switch s.headroomStrategy { - case HeadroomStrategyLeast: - // prefer smaller combined headroom (pack closer to limits) - w = int((1.0-combined)*float64(Wmax-minWeight)) + minWeight + 1 - case HeadroomStrategyMost: - // prefer larger combined headroom (more conservative / spread) - w = int(combined*float64(Wmax-minWeight)) + minWeight + 1 - default: - // Fallback to least - w = int((1.0-combined)*float64(Wmax-minWeight)) + minWeight + 1 - } - - weightedChoices = append(weightedChoices, Choice{PodName: p.Pod, Weight: w}) - total += w - - logger.V(logutil.TRACE).Info("Positive headroom blended weight", - "pod", p.Pod.GetPod().String(), - "ttftHeadroom", p.TTFTHeadroom, "normTTFTHeadroom", nTTFTH, - "tpotHeadroom", p.Headroom, "normTPOTHeadroom", nTPOTH, - "combined", combined, "weight", w) - } - - // Perform weighted random selection - idx := r.Intn(total) - var selectedPod schedulingtypes.Pod - - for _, c := range weightedChoices { - if idx < c.Weight { - selectedPod = c.PodName - break - } - 
idx -= c.Weight - } - - // If no pod was selected (shouldn't happen), fallback to first pod - if selectedPod == nil { - selectedPod = candidates[0].Pod - selectedPod = posHeadroomPods[0].Pod - } - - return selectedPod -} - -// selectFromNegativeHeadroomPods selects a pod from negative headroom pods using hierarchical TTFT/TPOT logic -// Modified to strictly prefer pods with 0 running requests -func (s *SLOAwareRouter) selectFromNegativeHeadroomPods(ctx context.Context, negHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod { - logger := log.FromContext(ctx) - - if len(negHeadroomPods) == 1 { - return negHeadroomPods[0].Pod - } - - // First, separate pods by running request count - var zeroRunningRequestPods, nonZeroRunningRequestPods []PodPredictionResult - - for _, p := range negHeadroomPods { - runningRequestCount := s.getPodRunningRequestCount(p.Pod) - if runningRequestCount == 0 { - zeroRunningRequestPods = append(zeroRunningRequestPods, p) - } else { - nonZeroRunningRequestPods = append(nonZeroRunningRequestPods, p) - } - } - - logger.V(logutil.DEBUG).Info("Negative headroom pods by running request count", - "zeroRunningRequests", len(zeroRunningRequestPods), - "nonZeroRunningRequests", len(nonZeroRunningRequestPods)) - - // If we have pods with 0 running requests, strictly prefer them - if len(zeroRunningRequestPods) > 0 { - logger.V(logutil.DEBUG).Info("Selecting from pods with zero running requests") - return s.selectFromNegativeHeadroomPodsInternal(ctx, zeroRunningRequestPods, r) - } - - // Otherwise, fall back to pods with running requests - logger.V(logutil.DEBUG).Info("No pods with zero running requests, selecting from pods with running requests") - return s.selectFromNegativeHeadroomPodsInternal(ctx, nonZeroRunningRequestPods, r) -} - -// selectFromNegativeHeadroomPodsInternal handles the actual selection logic for negative headroom pods -func (s *SLOAwareRouter) selectFromNegativeHeadroomPodsInternal(ctx context.Context, negHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod { - if len(negHeadroomPods) == 1 { - return negHeadroomPods[0].Pod - } - - // Apply perfect stickiness (with exploration) - candidates, sticky := s.epsilonGreedyAffinityGate(ctx, negHeadroomPods, r, "negative", AffinityGateTau) - - // If perfect stickiness collapsed us to a single pod, short-circuit - if sticky && len(candidates) == 1 { - return candidates[0].Pod - } - - const minWeightForNegative = 1 - - // Build weighted choices for selection - weightedChoices := make([]Choice, 0, len(candidates)) - total := 0 - - s.handleNegativeHeadroomPodsHierarchical(ctx, candidates, &weightedChoices, &total, minWeightForNegative) - - // Perform weighted random selection - idx := r.Intn(total) - var selectedPod schedulingtypes.Pod - - for _, c := range weightedChoices { - if idx < c.Weight { - selectedPod = c.PodName - break - } - idx -= c.Weight - } - - // If no pod was selected (shouldn't happen), fallback to first pod - if selectedPod == nil { - selectedPod = candidates[0].Pod - } - - return selectedPod -} - -// weightPodsByBlendedDeficit applies blended weighting using TTFT and TPOT deficits. -// Lower blended deficit => higher weight. 
-func (ps *SLOAwareRouter) weightPodsByBlendedDeficit( - ctx context.Context, - pods []PodPredictionResult, - choices *[]Choice, - total *int, - minWeight int, - alpha, beta float64, // weights for TTFT and TPOT deficits - category string, -) { - logger := log.FromContext(ctx) - if len(pods) == 0 { - return - } - - const Wrange = 80 - const eps = 1e-9 - - // Compute raw deficits (only when headroom is negative) - type deficits struct { - pod PodPredictionResult - ttftDef float64 - tpotDef float64 - } - defs := make([]deficits, 0, len(pods)) - - minTTFT, maxTTFT := math.MaxFloat64, -math.MaxFloat64 - minTPOT, maxTPOT := math.MaxFloat64, -math.MaxFloat64 - - for _, p := range pods { - ttftDef := 0.0 - if p.TTFTHeadroom < 0 { - ttftDef = -p.TTFTHeadroom - } - tpotDef := 0.0 - if p.Headroom < 0 { - tpotDef = -p.Headroom - } - defs = append(defs, deficits{pod: p, ttftDef: ttftDef, tpotDef: tpotDef}) - - if ttftDef < minTTFT { - minTTFT = ttftDef - } - if ttftDef > maxTTFT { - maxTTFT = ttftDef - } - if tpotDef < minTPOT { - minTPOT = tpotDef - } - if tpotDef > maxTPOT { - maxTPOT = tpotDef - } - } - - ttftRange := maxTTFT - minTTFT - tpotRange := maxTPOT - minTPOT - - // Normalize alpha/beta - if alpha+beta <= 0 { - alpha, beta = 1.0, 0.0 - } else { - sum := alpha + beta - alpha /= sum - beta /= sum - } - - logger.V(logutil.DEBUG).Info("Negative headroom blended deficits", - "category", category, - "minTTFTDef", minTTFT, "maxTTFTDef", maxTTFT, - "minTPOTDef", minTPOT, "maxTPOTDef", maxTPOT, - "alphaTTFT", alpha, "betaTPOT", beta, "podCount", len(pods)) - - for _, d := range defs { - // Normalize deficits to [0,1] within this bucket (0 = best / least violation) - nTTFT := 0.0 - if ttftRange > eps { - nTTFT = (d.ttftDef - minTTFT) / (ttftRange + eps) - } - nTPOT := 0.0 - if tpotRange > eps { - nTPOT = (d.tpotDef - minTPOT) / (tpotRange + eps) - } - - // Blended "badness": higher = worse violation - blended := alpha*nTTFT + beta*nTPOT - - // Convert to selection weight: lower badness -> higher weight - // Ensure a floor so no pod is completely excluded within the bucket. 
- w := int((1.0-blended)*float64(Wrange)) + minWeight + 1 - - *choices = append(*choices, Choice{PodName: d.pod.Pod, Weight: w}) - *total += w - - logger.V(logutil.TRACE).Info("Negative bucket blended weighting", - "pod", d.pod.Pod.GetPod().String(), - "ttftDef", d.ttftDef, "tpotDef", d.tpotDef, - "normTTFT", nTTFT, "normTPOT", nTPOT, - "blendedBadness", blended, "weight", w) - } -} - -func (s *SLOAwareRouter) handleNegativeHeadroomPodsHierarchical( - ctx context.Context, - negHeadroomPods []PodPredictionResult, - choices *[]Choice, - total *int, - minWeightForNegative int, -) { - logger := log.FromContext(ctx) - - // Categorize pods by their headroom status - var negTTFTNegTPOT, negTTFTNonNegTPOT, nonNegTTFTNegTPOT, nonNegTTFTNonNegTPOT []PodPredictionResult - - for _, p := range negHeadroomPods { - if p.TTFTHeadroom < 0 && p.Headroom < 0 { - negTTFTNegTPOT = append(negTTFTNegTPOT, p) - } else if p.TTFTHeadroom < 0 && p.Headroom >= 0 { - negTTFTNonNegTPOT = append(negTTFTNonNegTPOT, p) - } else if p.TTFTHeadroom >= 0 && p.Headroom < 0 { - nonNegTTFTNegTPOT = append(nonNegTTFTNegTPOT, p) - } else { - nonNegTTFTNonNegTPOT = append(nonNegTTFTNonNegTPOT, p) - } - } - - logger.V(logutil.DEBUG).Info("Hierarchical negative headroom pod distribution", - "totalNegative", len(negHeadroomPods), - "negTTFT_negTPOT", len(negTTFTNegTPOT), - "negTTFT_nonNegTPOT", len(negTTFTNonNegTPOT), - "nonNegTTFT_negTPOT", len(nonNegTTFTNegTPOT), - "nonNegTTFT_nonNegTPOT", len(nonNegTTFTNonNegTPOT)) - - // Priority 1: both TTFT and TPOT negative -> blended deficits (both active) - if len(negTTFTNegTPOT) > 0 { - s.weightPodsByBlendedDeficit(ctx, negTTFTNegTPOT, choices, total, minWeightForNegative, - NegHeadroomTTFTWeight, NegHeadroomTPOTWeight, "both_negative") - } - - // Priority 2: TTFT negative, TPOT non-negative -> blended still works (TPOT deficit=0) - if len(negTTFTNonNegTPOT) > 0 { - s.weightPodsByBlendedDeficit(ctx, negTTFTNonNegTPOT, choices, total, minWeightForNegative, - NegHeadroomTTFTWeight, NegHeadroomTPOTWeight, "ttft_negative") - } - - // Priority 3: TTFT non-negative, TPOT negative -> blended (TTFT deficit=0) - if len(nonNegTTFTNegTPOT) > 0 { - s.weightPodsByBlendedDeficit(ctx, nonNegTTFTNegTPOT, choices, total, minWeightForNegative, - NegHeadroomTTFTWeight, NegHeadroomTPOTWeight, "tpot_negative") - } - - // Priority 4: edge-case bucket -> minimal weight - for _, p := range nonNegTTFTNonNegTPOT { - *choices = append(*choices, Choice{PodName: p.Pod, Weight: minWeightForNegative}) - *total += minWeightForNegative - } -} - -// generatePredictions creates prediction results for all candidate pods -func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, sloCtx *SLORequestContext, candidatePods []schedulingtypes.Pod) []PodPredictionResult { - logger := log.FromContext(ctx) - predictions := make([]PodPredictionResult, 0, len(candidatePods)) - - for _, pod := range candidatePods { - predResult := PodPredictionResult{Pod: pod} - - logger.V(logutil.TRACE).Info("Candidate pod for scheduling", "pod", pod.GetPod().String(), "metrics", pod.GetMetrics().String()) - - // Get prefix cache score for the pod - prefixCacheScore := s.getPrefixCacheScoreForPod(ctx, state, pod) - - sloCtx.PrefixCacheScoresForPods[pod.GetPod().String()] = prefixCacheScore - - logger.V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", pod.GetPod().String(), "prefixCacheScore", prefixCacheScore) - - // Generate prediction - prediction, err := 
PredictWithMetrics(ctx, s.latencypredictor, pod.GetMetrics(), request.Body.Completions.Prompt, 1, prefixCacheScore) - if err != nil { - logger.V(logutil.DEBUG).Info("Skipping pod due to prediction error", "pod", pod.GetPod().String(), "error", err) - predResult.Error = err - predictions = append(predictions, predResult) - continue - } - predResult.PrefixCacheScore = prefixCacheScore - predResult.TTFT = prediction.TTFT - predResult.TPOT = prediction.TPOT - podMinTPOTSLO := 0.0 - //if pod.GetPod().RunningRequests.Peek() != nil { - // podMinTPOTSLO = pod.GetPod().RunningRequests.Peek().TPOT - //} - // Do this: - podMinTPOTSLO = s.getPodMinTPOTSLO(pod) - predResult.TTFTValid, predResult.TPOTValid, predResult.IsValid, predResult.Headroom, predResult.TTFTHeadroom = s.validatePrediction(prediction, sloCtx, podMinTPOTSLO) - - logger.V(logutil.DEBUG).Info("Prediction for scheduling", - "pod", pod.GetPod().String(), - "prefixCacheScore", prefixCacheScore, - "TTFT", prediction.TTFT, - "TPOT", prediction.TPOT, - "buffer", SLOBufferFactor, - "podMinTPOTSLO", podMinTPOTSLO, - "ttftSLO", sloCtx.TTFTSLO, - "requestTPOTSLO", sloCtx.AvgTPOTSLO, - "tpotHeadroom", predResult.Headroom, - "ttftHeadroom", predResult.TTFTHeadroom, - "tpotValid", predResult.TPOTValid, - "ttftValid", predResult.TTFTValid, - "headroomStrategy", s.headroomStrategy) - - predictions = append(predictions, predResult) - } - - return predictions -} - -func (s *SLOAwareRouter) getPodMinTPOTSLO(pod schedulingtypes.Pod) float64 { - podName := types.NamespacedName{ - Name: pod.GetPod().NamespacedName.Name, - Namespace: pod.GetPod().NamespacedName.Namespace, - } - if runningReqs, ok := s.runningRequestLists[podName]; ok && runningReqs.GetSize() > 0 { - if topReq := runningReqs.Peek(); topReq != nil { - return topReq.TPOT - } - } - return 0 // no running requests or no TPOT SLOs -} - -func (s *SLOAwareRouter) getPodRunningRequestCount(pod schedulingtypes.Pod) int { - podName := types.NamespacedName{ - Name: pod.GetPod().NamespacedName.Name, - Namespace: pod.GetPod().NamespacedName.Namespace, - } - if runningReqs, ok := s.runningRequestLists[podName]; ok { - return runningReqs.GetSize() - } - return 0 // no running requests -} - -func (s *SLOAwareRouter) validatePrediction( - pred *latencypredictor.PredictionResponse, - sloCtx *SLORequestContext, - podMinTPOTSLO float64, -) (ttftOk, tpotOk, isValid bool, headroom float64, ttftHeadroom float64) { - - bufferedTPOT := sloCtx.AvgTPOTSLO * SLOBufferFactor - // a podMinTPOTSLO of 0 means no either no requests, or no TPOT SLOs specified on running requests - if podMinTPOTSLO > 0 { - if podMinTPOTSLO < sloCtx.AvgTPOTSLO { - //print debug message - log.FromContext(context.Background()).V(logutil.DEBUG).Info("Pod min TPOT SLO is less than the req SLO, adjusting", "podMinTPOTSLO", podMinTPOTSLO, "bufferedTPOT", sloCtx.AvgTPOTSLO) - } - bufferedTPOT = min(bufferedTPOT, podMinTPOTSLO*SLOBufferFactor) - } - - tpotOk = pred.TPOT < bufferedTPOT - ttftOk = pred.TTFT < sloCtx.TTFTSLO - - isValid = ttftOk && tpotOk - headroom = bufferedTPOT - pred.TPOT - ttftHeadroom = sloCtx.TTFTSLO - pred.TTFT - return -} - -func (s *SLOAwareRouter) getPrefixCacheScoreForPod(ctx context.Context, cycleState *schedulingtypes.CycleState, pod schedulingtypes.Pod) float64 { - log.FromContext(ctx).V(logutil.DEBUG).Info("Running getPrefixCacheScoreForPod, getting prefix cache score for pod", "pod", pod.GetPod().String()) - plugintype := prefix.PrefixCachePluginType - pluginname := prefix.PrefixCachePluginType - cycleStateKey := 
(plugins.TypedName{Type: plugintype, Name: pluginname}).String()
-	stateData, err := cycleState.Read(plugins.StateKey(cycleStateKey))
-
-	log.FromContext(ctx).V(logutil.DEBUG).Info("Reading prefix cache state from cycle state", "stateKey", cycleStateKey)
-
-	if err != nil {
-		// The prefix cache plugin might not be enabled, which is a valid scenario.
-		log.FromContext(ctx).V(logutil.DEBUG).Info("Prefix cache state not found in cycle state, returning prefix cache score of 0.0", "pod", pod.GetPod().String())
-		return 0.0
-	}
-
-	prefixCacheState, ok := stateData.(*prefix.SchedulingContextState)
-	if !ok {
-		// This should not happen if the plugin is configured correctly.
-		log.FromContext(ctx).Error(fmt.Errorf("unexpected state type: %T", stateData), "failed to read prefix cache state")
-		return 0.0
-	}
-
-	total := len(prefixCacheState.PrefixHashes)
-	if total == 0 {
-		// if the request has no prefixes, return 0.0
-		log.FromContext(ctx).V(logutil.DEBUG).Info("No prefixes found in request, returning prefix cache score of 0.0")
-		return 0.0
-	}
-
-	matchLen := prefixCacheState.PrefixCacheServers[prefix.ServerID(pod.GetPod().NamespacedName)]
-	log.FromContext(ctx).V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", pod.GetPod().String(), "matchLen", matchLen, "totalPrefixes", total)
-	return float64(matchLen) / float64(total)
-}
-
-// updateRequestContextWithPredictions updates the request context with prediction data
-func (s *SLOAwareRouter) updateRequestContextWithPredictions(sloCtx *SLORequestContext, predictions []PodPredictionResult) {
-	for _, pred := range predictions {
-		if pred.Error == nil {
-			podKey := pred.Pod.GetPod().String()
-			if sloCtx.PredictedTTFTForScheduling == nil {
-				sloCtx.PredictedTTFTForScheduling = make(map[string]float64)
-			}
-			if sloCtx.PredictedTPOTForScheduling == nil {
-				sloCtx.PredictedTPOTForScheduling = make(map[string]float64)
-			}
-			sloCtx.PredictedTTFTForScheduling[podKey] = pred.TTFT
-			sloCtx.PredictedTPOTForScheduling[podKey] = pred.TPOT
-		}
-	}
-}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go
new file mode 100644
index 000000000..2645d112a
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go
@@ -0,0 +1,122 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package slo_aware_router contains helpers to decouple latency-predictor logic.
+package slo_aware_router
+
+import (
+	"context"
+
+	"sigs.k8s.io/controller-runtime/pkg/log"
+	latencypredictor "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/latencypredictorasync"
+	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+)
+
+// generatePredictions creates prediction results for all candidate pods
+func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, sloCtx *SLORequestContext, candidatePods []schedulingtypes.Pod) []PodPredictionResult {
+	logger := log.FromContext(ctx)
+	predictions := make([]PodPredictionResult, 0, len(candidatePods))
+
+	for _, pod := range candidatePods {
+		predResult := PodPredictionResult{Pod: pod}
+
+		logger.V(logutil.TRACE).Info("Candidate pod for scheduling", "pod", pod.GetPod().String(), "metrics", pod.GetMetrics().String())
+
+		// Get prefix cache score for the pod
+		prefixCacheScore := s.getPrefixCacheScoreForPod(ctx, state, pod)
+
+		sloCtx.PrefixCacheScoresForPods[pod.GetPod().String()] = prefixCacheScore
+
+		logger.V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", pod.GetPod().String(), "prefixCacheScore", prefixCacheScore)
+
+		// Generate prediction
+		prediction, err := PredictWithMetrics(ctx, s.latencypredictor, pod.GetMetrics(), request.Body.Completions.Prompt, 1, prefixCacheScore)
+		if err != nil {
+			logger.V(logutil.DEBUG).Info("Skipping pod due to prediction error", "pod", pod.GetPod().String(), "error", err)
+			predResult.Error = err
+			predictions = append(predictions, predResult)
+			continue
+		}
+		predResult.PrefixCacheScore = prefixCacheScore
+		predResult.TTFT = prediction.TTFT
+		predResult.TPOT = prediction.TPOT
+		// Use the tightest TPOT SLO among the pod's running requests (0 when none).
+		podMinTPOTSLO := s.getPodMinTPOTSLO(pod)
+		predResult.TTFTValid, predResult.TPOTValid, predResult.IsValid, predResult.Headroom, predResult.TTFTHeadroom = s.validatePrediction(prediction, sloCtx, podMinTPOTSLO)
+
+		logger.V(logutil.DEBUG).Info("Prediction for scheduling",
+			"pod", pod.GetPod().String(),
+			"prefixCacheScore", prefixCacheScore,
+			"TTFT", prediction.TTFT,
+			"TPOT", prediction.TPOT,
+			"buffer", SLOBufferFactor,
+			"podMinTPOTSLO", podMinTPOTSLO,
+			"ttftSLO", sloCtx.TTFTSLO,
+			"requestTPOTSLO", sloCtx.AvgTPOTSLO,
+			"tpotHeadroom", predResult.Headroom,
+			"ttftHeadroom", predResult.TTFTHeadroom,
+			"tpotValid", predResult.TPOTValid,
+			"ttftValid", predResult.TTFTValid,
+			"headroomStrategy", s.headroomStrategy)
+
+		predictions = append(predictions, predResult)
+	}
+
+	return predictions
+}
+
+// updateRequestContextWithPredictions updates the request context with prediction data
+func (s *SLOAwareRouter) updateRequestContextWithPredictions(sloCtx *SLORequestContext, predictions []PodPredictionResult) {
+	for _, pred := range predictions {
+		if pred.Error == nil {
+			podKey := pred.Pod.GetPod().String()
+			if sloCtx.PredictedTTFTForScheduling == nil {
+				sloCtx.PredictedTTFTForScheduling = make(map[string]float64)
+			}
+			if sloCtx.PredictedTPOTForScheduling == nil {
+				sloCtx.PredictedTPOTForScheduling = make(map[string]float64)
+			}
+			sloCtx.PredictedTTFTForScheduling[podKey] = pred.TTFT
+			sloCtx.PredictedTPOTForScheduling[podKey] = pred.TPOT
+		}
+	}
+}
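Review note: the validity check that follows compares predictions against a buffered TPOT budget. A worked example with made-up numbers: with `SLOBufferFactor = 0.9`, a request SLO of 50 ms/token, and a co-scheduled request whose SLO is 40 ms, the effective budget is `min(50*0.9, 40*0.9) = 36 ms`, so a predicted TPOT of 30 ms passes with 6 ms of headroom:

```go
func exampleBufferedHeadroom() {
	const bufferFactor = 0.9
	requestSLO := 50.0 // ms/token, from x-slo-tpot-ms
	podMinSLO := 40.0  // tightest TPOT SLO among the pod's running requests

	bufferedTPOT := requestSLO * bufferFactor                // 45
	bufferedTPOT = min(bufferedTPOT, podMinSLO*bufferFactor) // 36

	predictedTPOT := 30.0
	fmt.Println(bufferedTPOT - predictedTPOT) // 6 ms of TPOT headroom -> TPOT-valid
}
```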
+
+func (s *SLOAwareRouter) validatePrediction(
+	pred *latencypredictor.PredictionResponse,
+	sloCtx *SLORequestContext,
+	podMinTPOTSLO float64,
+) (ttftOk, tpotOk, isValid bool, headroom float64, ttftHeadroom float64) {
+	bufferedTPOT := sloCtx.AvgTPOTSLO * SLOBufferFactor
+	// a podMinTPOTSLO of 0 means either no running requests, or no TPOT SLOs specified on them
+	if podMinTPOTSLO > 0 {
+		if podMinTPOTSLO < sloCtx.AvgTPOTSLO {
+			log.FromContext(context.Background()).V(logutil.DEBUG).Info("Pod min TPOT SLO is less than the request SLO, adjusting", "podMinTPOTSLO", podMinTPOTSLO, "requestTPOTSLO", sloCtx.AvgTPOTSLO)
+		}
+		bufferedTPOT = min(bufferedTPOT, podMinTPOTSLO*SLOBufferFactor)
+	}
+
+	tpotOk = pred.TPOT < bufferedTPOT
+	ttftOk = pred.TTFT < sloCtx.TTFTSLO
+
+	isValid = ttftOk && tpotOk
+	headroom = bufferedTPOT - pred.TPOT
+	ttftHeadroom = sloCtx.TTFTSLO - pred.TTFT
+	return
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go
index dbe7b46f4..17399c6a8 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go
@@ -61,27 +61,30 @@ type SLORequestContext struct {
 
 func NewSLORequestContext(request *schedulingtypes.LLMRequest) *SLORequestContext {
 	return &SLORequestContext{
-		SchedulingRequest: *request,
-		LastSeenMetrics:   make(map[string]*backendmetrics.MetricsState),
+		SchedulingRequest:          *request,
+		LastSeenMetrics:            make(map[string]*backendmetrics.MetricsState),
+		PrefixCacheScoresForPods:   make(map[string]float64),
+		PredictedTTFTForScheduling: make(map[string]float64),
+		PredictedTPOTForScheduling: make(map[string]float64),
 	}
 }
 
 func (s *SLOAwareRouter) getSLOContextForRequest(request *schedulingtypes.LLMRequest) (*SLORequestContext, error) {
 	id := request.Headers[requtil.RequestIdHeaderKey]
-	if ctx, exists := s.sloContextStore[id]; exists {
-		return ctx, nil
+	if ctx, exists := s.sloContextStore.Load(id); exists {
+		return ctx.(*SLORequestContext), nil
 	}
 	return nil, fmt.Errorf("SLO context not found for request ID: %s", id)
 }
 
 func (s *SLOAwareRouter) setSLOContextForRequest(request *schedulingtypes.LLMRequest, ctx *SLORequestContext) {
 	id := request.Headers[requtil.RequestIdHeaderKey]
-	s.sloContextStore[id] = ctx
+	s.sloContextStore.Store(id, ctx)
 }
 
 func (s *SLOAwareRouter) deleteSLOContextForRequest(request *schedulingtypes.LLMRequest) {
 	id := request.Headers[requtil.RequestIdHeaderKey]
-	delete(s.sloContextStore, id)
+	s.sloContextStore.Delete(id)
 }
 
 // --- RequestControl Hooks ---
@@ -130,6 +133,7 @@ func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtype
 	// Set up SLO request context
 	sloCtx.TargetPod = targetPod
 	sloCtx.SchedulingResult = schedulingResult
+	sloCtx.RequestReceivedTimestamp = time.Now()
 	RefreshLastSeenMetrics(ctx, sloCtx)
 	t.setSLOContextForRequest(request, sloCtx)
 }
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go
new file mode 100644
index 000000000..9ecb396af
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go
@@ -0,0 +1,296 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "context" + "fmt" + "math/rand" + "sync" + "time" + + "sigs.k8s.io/controller-runtime/pkg/log" + + "k8s.io/apimachinery/pkg/types" + latencypredictor "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/latencypredictorasync" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +type PodPredictionResult struct { + Pod schedulingtypes.Pod + TTFT float64 + TPOT float64 + TTFTValid bool + TPOTValid bool + IsValid bool + Error error + Headroom float64 // Headroom for the pod, if applicable + TTFTHeadroom float64 // TTFT headroom for the pod + PrefixCacheScore float64 // Prefix cache score for the pod +} + +type SLOAwareRouter struct { + tn plugins.TypedName + latencypredictor latencypredictor.PredictorInterface + runningRequestLists map[types.NamespacedName]*RequestPriorityQueue + sloContextStore sync.Map // map[string]*SLORequestContext + headroomStrategy HeadroomStrategy +} + +func (s *SLOAwareRouter) Dependencies() []plugins.TypedName { + return []plugins.TypedName{ + {Type: "prefix-cache-scorer", Name: "prefix-cache-scorer"}, + } +} + +var _ framework.Scorer = &SLOAwareRouter{} + +func NewSLOAwareRouter(latencypredictor latencypredictor.PredictorInterface, strategy HeadroomStrategy) *SLOAwareRouter { + return &SLOAwareRouter{ + tn: plugins.TypedName{Type: SLOAwareRouterPluginType, Name: SLOAwareRouterPluginType}, + latencypredictor: latencypredictor, + runningRequestLists: make(map[types.NamespacedName]*RequestPriorityQueue), + sloContextStore: sync.Map{}, + headroomStrategy: strategy, + } +} + +func (s *SLOAwareRouter) TypedName() plugins.TypedName { + return s.tn +} + +func (s *SLOAwareRouter) WithName(name string) *SLOAwareRouter { + s.tn.Name = name + return s +} + +// SetHeadroomStrategy allows runtime configuration of headroom selection strategy +func (s *SLOAwareRouter) SetHeadroomStrategy(strategy HeadroomStrategy) { + s.headroomStrategy = strategy +} + +// GetHeadroomStrategy returns the current headroom selection strategy +func (s *SLOAwareRouter) GetHeadroomStrategy() HeadroomStrategy { + return s.headroomStrategy +} + +func (s *SLOAwareRouter) epsilonGreedyAffinityGate( + ctx context.Context, + candidates []PodPredictionResult, + r *rand.Rand, + label string, // e.g. "positive" or "negative" + prefixStickyThreshold float64, +) ([]PodPredictionResult, bool) { + logger := log.FromContext(ctx) + + eligible := make([]PodPredictionResult, 0, len(candidates)) + for _, p := range candidates { + if p.PrefixCacheScore >= prefixStickyThreshold { + eligible = append(eligible, p) + } + } + + // No eligible sticky pods? Explore (no gating). 
+ if len(eligible) == 0 { + return candidates, false + } + + // ε-exploration branch + if r.Float64() < EpsilonExploreSticky { + logger.V(logutil.DEBUG).Info("ε-greedy: exploring (ignoring affinity gate)", + "path", label, "epsilon", EpsilonExploreSticky, "eligibleCount", len(eligible)) + return candidates, false + } + + logger.V(logutil.DEBUG).Info("ε-greedy: exploiting (apply affinity gate)", + "path", label, "threshold", prefixStickyThreshold, "eligibleCount", len(eligible), "total", len(candidates)) + return eligible, true +} + +func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) map[schedulingtypes.Pod]float64 { + logger := log.FromContext(ctx) + if s.latencypredictor == nil { + logger.V(logutil.DEBUG).Info("SLOAwareRouter: no predictor configured, returning nil scores") + return nil + } + + sloCtx := s.getOrMakeSLORequestContext(request) + + var err error + // get request slos + // Get Request SLOs from request header + sloCtx.TTFTSLO, _, err = parseFloatHeader(*request, TTFTSLOHeaderKey) + if err != nil { + logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", TTFTSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TTFT SLO from header") + } + + sloCtx.AvgTPOTSLO, _, err = parseFloatHeader(*request, TPOTSLOHeaderKey) + if err != nil { + logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", TPOTSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TPOT SLO from header") + } + sloCtx.PredictorBasedScheduling, err = parseBoolHeader(*request, "x-prediction-based-scheduling") + if err != nil { + logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("x-prediction-based-scheduling must be a bool: %v", err)}, "SLOAwareRouter: Error parsing PredictorBasedScheduling from header") + } + + // Check if SLOs are provided + if !sloCtx.PredictorBasedScheduling { + logger.V(logutil.DEBUG).Info("PredictorBasedScheduling turned off, skipping prediction-based filtering") + return nil + } + + predictions := s.generatePredictions(ctx, state, request, sloCtx, pods) + s.updateRequestContextWithPredictions(sloCtx, predictions) + + allPreds := append([]PodPredictionResult(nil), predictions...) 
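+	// Note: this copy is defensive; the affinity gate below returns a
+	// filtered slice and reassigns allPreds, leaving `predictions` (already
+	// recorded in sloCtx above) untouched.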
+
+	// Initialize scores map with all pods having score 0
+	scores := make(map[schedulingtypes.Pod]float64, len(pods))
+	for _, pod := range pods {
+		scores[pod] = 0
+	}
+
+	source := rand.NewSource(time.Now().UnixNano())
+	r := rand.New(source)
+	allPreds, sticky := s.epsilonGreedyAffinityGate(ctx, allPreds, r, "overall", AffinityGateTauGlobal)
+
+	// Check if all pods are invalid and all have running requests
+	allPodsInvalid := true
+	allPodsHaveRunningRequests := true
+
+	for _, pred := range allPreds {
+		if pred.IsValid {
+			allPodsInvalid = false
+		}
+
+		runningRequestCount := s.getPodRunningRequestCount(pred.Pod)
+		if runningRequestCount == 0 {
+			allPodsHaveRunningRequests = false
+		}
+	}
+
+	// Set HasValidPod to false if all pods are invalid and all have running requests
+	if allPodsInvalid && allPodsHaveRunningRequests && !sticky {
+		sloCtx.HasValidPod = false
+		logger.V(logutil.DEBUG).Info("All pods are invalid and have running requests, setting HasValidPod to false")
+	}
+
+	// Tiered selection: exploit positive-headroom pods with probability 1-ε and
+	// explore negative-headroom pods with probability ε (ε = EpsilonExploreNeg, 1% by default).
+	var posHeadroomPods, negHeadroomPods []PodPredictionResult
+	for _, p := range allPreds {
+		// A pod has positive headroom only if BOTH TTFT and TPOT have positive headroom
+		if p.Headroom > 0 && p.TTFTHeadroom > 0 {
+			posHeadroomPods = append(posHeadroomPods, p)
+		} else {
+			// A pod has negative headroom if EITHER TTFT or TPOT has negative/zero headroom
+			negHeadroomPods = append(negHeadroomPods, p)
+		}
+	}
+
+	logger.V(logutil.DEBUG).Info("Pod headroom distribution",
+		"positivePods", len(posHeadroomPods),
+		"negativePods", len(negHeadroomPods))
+
+	var selectedPod schedulingtypes.Pod
+
+	if s.headroomStrategy == HeadroomStrategyCompositeOnly {
+		logger.V(logutil.DEBUG).Info("Selecting from composite scores only")
+		selectedPod = s.selectFromCompositeScores(ctx, allPreds, r, HeadroomStrategyCompositeOnly)
+	} else if len(posHeadroomPods) > 0 && len(negHeadroomPods) > 0 {
+		// With probability EpsilonExploreNeg, select from the negative-headroom tier; otherwise from the positive tier.
+		if r.Float64() < EpsilonExploreNeg {
+			logger.V(logutil.DEBUG).Info("Selecting from negative headroom pods (epsilon exploration)")
+			selectedPod = s.selectFromNegativeHeadroomPods(ctx, negHeadroomPods, r)
+		} else {
+			logger.V(logutil.DEBUG).Info("Selecting from positive headroom pods (exploitation)")
+			selectedPod = s.selectFromPositiveHeadroomPods(ctx, posHeadroomPods, r)
+		}
+	} else if len(posHeadroomPods) > 0 {
+		// If only positive headroom pods exist, select from them
+		logger.V(logutil.DEBUG).Info("Only positive headroom pods available")
+		selectedPod = s.selectFromPositiveHeadroomPods(ctx, posHeadroomPods, r)
+	} else if len(negHeadroomPods) > 0 {
+		// If only negative headroom pods exist, select from them
+		logger.V(logutil.DEBUG).Info("Only negative headroom pods available")
+		selectedPod = s.selectFromNegativeHeadroomPods(ctx, negHeadroomPods, r)
+	} else if len(allPreds) > 0 {
+		// Fallback: select randomly from valid pods
+		logger.V(logutil.DEBUG).Info("No headroom pods available, selecting randomly from valid pods")
+		selectedPod = allPreds[r.Intn(len(allPreds))].Pod
+	} else {
+		// No valid pods - return all zeros
+		logger.V(logutil.DEBUG).Info("No valid pods available, returning all zero scores")
+		return scores
+	}
+
+	// Set score = 1 for selected pod, 0 for all others
+	if selectedPod != nil {
+		scores[selectedPod] = 1
+		logger.V(logutil.DEBUG).Info("Selected pod for scheduling", "pod", selectedPod.GetPod().String())
+	}
+
+	s.setSLOContextForRequest(request, sloCtx)
+
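+	// At this point scores holds 1 for the selected pod and 0 for every other pod.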
+	return scores
+}
+
+func (s *SLOAwareRouter) getOrMakeSLORequestContext(request *schedulingtypes.LLMRequest) *SLORequestContext {
+	sloCtx, err := s.getSLOContextForRequest(request)
+	if err != nil {
+		sloCtx = NewSLORequestContext(request)
+	}
+	return sloCtx
+}
+
+func (s *SLOAwareRouter) getPrefixCacheScoreForPod(ctx context.Context, cycleState *schedulingtypes.CycleState, pod schedulingtypes.Pod) float64 {
+	log.FromContext(ctx).V(logutil.DEBUG).Info("Running getPrefixCacheScoreForPod, getting prefix cache score for pod", "pod", pod.GetPod().String())
+	pluginType := prefix.PrefixCachePluginType
+	pluginName := prefix.PrefixCachePluginType
+	cycleStateKey := (plugins.TypedName{Type: pluginType, Name: pluginName}).String()
+	stateData, err := cycleState.Read(plugins.StateKey(cycleStateKey))
+
+	log.FromContext(ctx).V(logutil.DEBUG).Info("Reading prefix cache state from cycle state", "stateKey", cycleStateKey)
+
+	if err != nil {
+		// The prefix cache plugin might not be enabled, which is a valid scenario.
+		log.FromContext(ctx).V(logutil.DEBUG).Info("Prefix cache state not found in cycle state, returning prefix cache score of 0.0", "pod", pod.GetPod().String())
+		return 0.0
+	}
+
+	prefixCacheState, ok := stateData.(*prefix.SchedulingContextState)
+	if !ok {
+		// This should not happen if the plugin is configured correctly.
+		log.FromContext(ctx).Error(fmt.Errorf("unexpected state type: %T", stateData), "failed to read prefix cache state")
+		return 0.0
+	}
+
+	total := len(prefixCacheState.PrefixHashes)
+	if total == 0 {
+		// If the request has no prefixes, return 0.0.
+		log.FromContext(ctx).V(logutil.DEBUG).Info("No prefixes found in request, returning prefix cache score of 0.0")
+		return 0.0
+	}
+
+	matchLen := prefixCacheState.PrefixCacheServers[prefix.ServerID(pod.GetPod().NamespacedName)]
+	log.FromContext(ctx).V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", pod.GetPod().String(), "matchLen", matchLen, "totalPrefixes", total)
+	return float64(matchLen) / float64(total)
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go
new file mode 100644
index 000000000..34618ce19
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go
@@ -0,0 +1,381 @@
+/*
+© 2025 The Kubernetes Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+	http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package slo_aware_router contains helpers to decouple latency-predictor logic.
+package slo_aware_router
+
+import (
+	"context"
+	"math"
+	"math/rand"
+
+	"k8s.io/apimachinery/pkg/types"
+	"sigs.k8s.io/controller-runtime/pkg/log"
+	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+)
+
+// selectFromPositiveHeadroomPods selects a pod from the positive-headroom pods using the
+// configured headroom strategy, blending TTFT headroom with TPOT headroom.
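+// The blended score is combined = α·normTTFTHeadroom + β·normTPOTHeadroom, where
+// α = HeadroomTTFTWeight and β = HeadroomTPOTWeight are renormalized to sum to 1.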
+func (s *SLOAwareRouter) selectFromPositiveHeadroomPods(ctx context.Context, posHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod {
+	logger := log.FromContext(ctx)
+
+	if len(posHeadroomPods) == 1 {
+		return posHeadroomPods[0].Pod
+	}
+
+	// Apply perfect stickiness (with exploration)
+	candidates, sticky := s.epsilonGreedyAffinityGate(ctx, posHeadroomPods, r, "positive", AffinityGateTau)
+
+	// If perfect stickiness collapsed us to a single pod, short-circuit
+	if sticky && len(candidates) == 1 {
+		return candidates[0].Pod
+	}
+	switch s.headroomStrategy {
+	case HeadroomStrategyCompositeMost:
+		return s.selectFromCompositeScores(ctx, candidates, r, HeadroomStrategyCompositeMost)
+	case HeadroomStrategyCompositeLeast:
+		return s.selectFromCompositeScores(ctx, candidates, r, HeadroomStrategyCompositeLeast)
+	}
+
+	// Find min/max for TPOT (Headroom) and TTFTHeadroom across positive pods to normalize to [0,1]
+	minTPOTH, maxTPOTH := math.MaxFloat64, -math.MaxFloat64
+	minTTFTH, maxTTFTH := math.MaxFloat64, -math.MaxFloat64
+
+	for _, p := range candidates {
+		if p.Headroom < minTPOTH {
+			minTPOTH = p.Headroom
+		}
+		if p.Headroom > maxTPOTH {
+			maxTPOTH = p.Headroom
+		}
+		if p.TTFTHeadroom < minTTFTH {
+			minTTFTH = p.TTFTHeadroom
+		}
+		if p.TTFTHeadroom > maxTTFTH {
+			maxTTFTH = p.TTFTHeadroom
+		}
+	}
+
+	tpotRange := maxTPOTH - minTPOTH
+	ttftRange := maxTTFTH - minTTFTH
+
+	// Precompute blend weights (renormalize if user sets both to 0)
+	alpha := HeadroomTTFTWeight
+	beta := HeadroomTPOTWeight
+	if alpha+beta <= 0 {
+		alpha = 1.0
+		beta = 0.0
+	}
+	sum := alpha + beta
+	alpha /= sum
+	beta /= sum
+
+	logger.V(logutil.DEBUG).Info("Positive headroom normalization ranges",
+		"minTPOTHeadroom", minTPOTH, "maxTPOTHeadroom", maxTPOTH,
+		"minTTFTHeadroom", minTTFTH, "maxTTFTHeadroom", maxTTFTH,
+		"alphaTTFT", alpha, "betaTPOT", beta, "strategy", s.headroomStrategy)
+
+	// Calculate weights for weighted random selection
+	weightedChoices := make([]Choice, 0, len(candidates))
+	total := 0
+
+	for _, p := range candidates {
+		// Normalize to [0,1] within the cohort
+		nTPOTH := 0.5
+		if tpotRange > eps {
+			nTPOTH = (p.Headroom - minTPOTH) / (tpotRange + eps)
+		}
+		nTTFTH := 0.5
+		if ttftRange > eps {
+			nTTFTH = (p.TTFTHeadroom - minTTFTH) / (ttftRange + eps)
+		}
+
+		// Blend: larger combined -> "safer"; smaller -> "tighter packing"
+		combined := alpha*nTTFTH + beta*nTPOTH
+
+		// Map to integer weights
+		var w int
+		switch s.headroomStrategy {
+		case HeadroomStrategyLeast:
+			// prefer smaller combined headroom (pack closer to limits)
+			w = int((1.0-combined)*float64(Wmax-minWeight)) + minWeight + 1
+		case HeadroomStrategyMost:
+			// prefer larger combined headroom (more conservative / spread)
+			w = int(combined*float64(Wmax-minWeight)) + minWeight + 1
+		default:
+			// Fall back to least
+			w = int((1.0-combined)*float64(Wmax-minWeight)) + minWeight + 1
+		}
+
+		weightedChoices = append(weightedChoices, Choice{PodName: p.Pod, Weight: w})
+		total += w
+
+		logger.V(logutil.TRACE).Info("Positive headroom blended weight",
+			"pod", p.Pod.GetPod().String(),
+			"ttftHeadroom", p.TTFTHeadroom, "normTTFTHeadroom", nTTFTH,
+			"tpotHeadroom", p.Headroom, "normTPOTHeadroom", nTPOTH,
+			"combined", combined, "weight", w)
+	}
+
+	return s.performWeightedRandomSelection(weightedChoices, total, candidates, r)
+}
+
+// selectFromNegativeHeadroomPods selects a pod from the negative-headroom pods using
+// hierarchical TTFT/TPOT logic; pods with zero running requests are strictly preferred.
+func (s *SLOAwareRouter) selectFromNegativeHeadroomPods(ctx context.Context, negHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod {
+	logger := log.FromContext(ctx)
+
+	if len(negHeadroomPods) == 1 {
+		return negHeadroomPods[0].Pod
+	}
+
+	// First, separate pods by running request count
+	var zeroRunningRequestPods, nonZeroRunningRequestPods []PodPredictionResult
+
+	for _, p := range negHeadroomPods {
+		runningRequestCount := s.getPodRunningRequestCount(p.Pod)
+		if runningRequestCount == 0 {
+			zeroRunningRequestPods = append(zeroRunningRequestPods, p)
+		} else {
+			nonZeroRunningRequestPods = append(nonZeroRunningRequestPods, p)
+		}
+	}
+
+	logger.V(logutil.DEBUG).Info("Negative headroom pods by running request count",
+		"zeroRunningRequests", len(zeroRunningRequestPods),
+		"nonZeroRunningRequests", len(nonZeroRunningRequestPods))
+
+	// If we have pods with 0 running requests, strictly prefer them
+	if len(zeroRunningRequestPods) > 0 {
+		logger.V(logutil.DEBUG).Info("Selecting from pods with zero running requests")
+		return s.selectFromNegativeHeadroomPodsInternal(ctx, zeroRunningRequestPods, r)
+	}
+
+	// Otherwise, fall back to pods with running requests
+	logger.V(logutil.DEBUG).Info("No pods with zero running requests, selecting from pods with running requests")
+	return s.selectFromNegativeHeadroomPodsInternal(ctx, nonZeroRunningRequestPods, r)
+}
+
+// selectFromNegativeHeadroomPodsInternal handles the actual selection logic for negative headroom pods
+func (s *SLOAwareRouter) selectFromNegativeHeadroomPodsInternal(ctx context.Context, negHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod {
+	if len(negHeadroomPods) == 1 {
+		return negHeadroomPods[0].Pod
+	}
+
+	// Apply perfect stickiness (with exploration)
+	candidates, sticky := s.epsilonGreedyAffinityGate(ctx, negHeadroomPods, r, "negative", AffinityGateTau)
+
+	// If perfect stickiness collapsed us to a single pod, short-circuit
+	if sticky && len(candidates) == 1 {
+		return candidates[0].Pod
+	}
+
+	switch s.headroomStrategy {
+	case HeadroomStrategyCompositeMost:
+		return s.selectFromCompositeScores(ctx, candidates, r, HeadroomStrategyCompositeMost)
+	case HeadroomStrategyCompositeLeast:
+		return s.selectFromCompositeScores(ctx, candidates, r, HeadroomStrategyCompositeLeast)
+	}
+
+	// Build weighted choices for selection
+	weightedChoices := make([]Choice, 0, len(candidates))
+	total := 0
+
+	s.handleNegativeHeadroomPodsHierarchical(ctx, candidates, &weightedChoices, &total, minWeight)
+
+	// Perform weighted random selection
+	return s.performWeightedRandomSelection(weightedChoices, total, candidates, r)
+}
+
+// weightPodsByBlendedDeficit applies blended weighting using TTFT and TPOT deficits.
+// Lower blended deficit => higher weight.
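+// Deficits are normalized to [0,1] within the bucket, blended as
+// α·normTTFTDeficit + β·normTPOTDeficit, and inverted into integer weights so the
+// least-violating pod receives the largest share.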
+func (s *SLOAwareRouter) weightPodsByBlendedDeficit(
+	ctx context.Context,
+	pods []PodPredictionResult,
+	choices *[]Choice,
+	total *int,
+	minWeight int,
+	alpha, beta float64, // weights for TTFT and TPOT deficits
+	category string,
+) {
+	logger := log.FromContext(ctx)
+	if len(pods) == 0 {
+		return
+	}
+
+	const Wrange = 80
+
+	// Compute raw deficits (only when headroom is negative)
+	type deficits struct {
+		pod     PodPredictionResult
+		ttftDef float64
+		tpotDef float64
+	}
+	defs := make([]deficits, 0, len(pods))
+
+	minTTFT, maxTTFT := math.MaxFloat64, -math.MaxFloat64
+	minTPOT, maxTPOT := math.MaxFloat64, -math.MaxFloat64
+
+	for _, p := range pods {
+		ttftDef := 0.0
+		if p.TTFTHeadroom < 0 {
+			ttftDef = -p.TTFTHeadroom
+		}
+		tpotDef := 0.0
+		if p.Headroom < 0 {
+			tpotDef = -p.Headroom
+		}
+		defs = append(defs, deficits{pod: p, ttftDef: ttftDef, tpotDef: tpotDef})
+
+		if ttftDef < minTTFT {
+			minTTFT = ttftDef
+		}
+		if ttftDef > maxTTFT {
+			maxTTFT = ttftDef
+		}
+		if tpotDef < minTPOT {
+			minTPOT = tpotDef
+		}
+		if tpotDef > maxTPOT {
+			maxTPOT = tpotDef
+		}
+	}
+
+	ttftRange := maxTTFT - minTTFT
+	tpotRange := maxTPOT - minTPOT
+
+	// Normalize alpha/beta
+	if alpha+beta <= 0 {
+		alpha, beta = 1.0, 0.0
+	} else {
+		sum := alpha + beta
+		alpha /= sum
+		beta /= sum
+	}
+
+	logger.V(logutil.DEBUG).Info("Negative headroom blended deficits",
+		"category", category,
+		"minTTFTDef", minTTFT, "maxTTFTDef", maxTTFT,
+		"minTPOTDef", minTPOT, "maxTPOTDef", maxTPOT,
+		"alphaTTFT", alpha, "betaTPOT", beta, "podCount", len(pods))
+
+	for _, d := range defs {
+		// Normalize deficits to [0,1] within this bucket (0 = best / least violation)
+		nTTFT := 0.0
+		if ttftRange > eps {
+			nTTFT = (d.ttftDef - minTTFT) / (ttftRange + eps)
+		}
+		nTPOT := 0.0
+		if tpotRange > eps {
+			nTPOT = (d.tpotDef - minTPOT) / (tpotRange + eps)
+		}
+
+		// Blended "badness": higher = worse violation
+		blended := alpha*nTTFT + beta*nTPOT
+
+		// Convert to selection weight: lower badness -> higher weight
+		// Ensure a floor so no pod is completely excluded within the bucket.
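+		// Weights land in [minWeight+1, Wrange+minWeight+1]: blended badness of 0
+		// earns the full Wrange, badness of 1 earns only the floor.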
+		w := int((1.0-blended)*float64(Wrange)) + minWeight + 1
+
+		*choices = append(*choices, Choice{PodName: d.pod.Pod, Weight: w})
+		*total += w
+
+		logger.V(logutil.TRACE).Info("Negative bucket blended weighting",
+			"pod", d.pod.Pod.GetPod().String(),
+			"ttftDef", d.ttftDef, "tpotDef", d.tpotDef,
+			"normTTFT", nTTFT, "normTPOT", nTPOT,
+			"blendedBadness", blended, "weight", w)
+	}
+}
+
+func (s *SLOAwareRouter) handleNegativeHeadroomPodsHierarchical(
+	ctx context.Context,
+	negHeadroomPods []PodPredictionResult,
+	choices *[]Choice,
+	total *int,
+	minWeightForNegative int,
+) {
+	logger := log.FromContext(ctx)
+
+	// Categorize pods by their headroom status
+	var negTTFTNegTPOT, negTTFTNonNegTPOT, nonNegTTFTNegTPOT, nonNegTTFTNonNegTPOT []PodPredictionResult
+
+	for _, p := range negHeadroomPods {
+		if p.TTFTHeadroom < 0 && p.Headroom < 0 {
+			negTTFTNegTPOT = append(negTTFTNegTPOT, p)
+		} else if p.TTFTHeadroom < 0 && p.Headroom >= 0 {
+			negTTFTNonNegTPOT = append(negTTFTNonNegTPOT, p)
+		} else if p.TTFTHeadroom >= 0 && p.Headroom < 0 {
+			nonNegTTFTNegTPOT = append(nonNegTTFTNegTPOT, p)
+		} else {
+			nonNegTTFTNonNegTPOT = append(nonNegTTFTNonNegTPOT, p)
+		}
+	}
+
+	logger.V(logutil.DEBUG).Info("Hierarchical negative headroom pod distribution",
+		"totalNegative", len(negHeadroomPods),
+		"negTTFT_negTPOT", len(negTTFTNegTPOT),
+		"negTTFT_nonNegTPOT", len(negTTFTNonNegTPOT),
+		"nonNegTTFT_negTPOT", len(nonNegTTFTNegTPOT),
+		"nonNegTTFT_nonNegTPOT", len(nonNegTTFTNonNegTPOT))
+
+	// Priority 1: both TTFT and TPOT negative -> blended deficits (both active)
+	if len(negTTFTNegTPOT) > 0 {
+		s.weightPodsByBlendedDeficit(ctx, negTTFTNegTPOT, choices, total, minWeightForNegative,
+			NegHeadroomTTFTWeight, NegHeadroomTPOTWeight, "both_negative")
+	}
+
+	// Priority 2: TTFT negative, TPOT non-negative -> blended still works (TPOT deficit=0)
+	if len(negTTFTNonNegTPOT) > 0 {
+		s.weightPodsByBlendedDeficit(ctx, negTTFTNonNegTPOT, choices, total, minWeightForNegative,
+			NegHeadroomTTFTWeight, NegHeadroomTPOTWeight, "ttft_negative")
+	}
+
+	// Priority 3: TTFT non-negative, TPOT negative -> blended (TTFT deficit=0)
+	if len(nonNegTTFTNegTPOT) > 0 {
+		s.weightPodsByBlendedDeficit(ctx, nonNegTTFTNegTPOT, choices, total, minWeightForNegative,
+			NegHeadroomTTFTWeight, NegHeadroomTPOTWeight, "tpot_negative")
+	}
+
+	// Priority 4: edge-case bucket -> minimal weight
+	for _, p := range nonNegTTFTNonNegTPOT {
+		*choices = append(*choices, Choice{PodName: p.Pod, Weight: minWeightForNegative})
+		*total += minWeightForNegative
+	}
+}
+
+func (s *SLOAwareRouter) getPodMinTPOTSLO(pod schedulingtypes.Pod) float64 {
+	podName := types.NamespacedName{
+		Name:      pod.GetPod().NamespacedName.Name,
+		Namespace: pod.GetPod().NamespacedName.Namespace,
+	}
+	if runningReqs, ok := s.runningRequestLists[podName]; ok && runningReqs.GetSize() > 0 {
+		if topReq := runningReqs.Peek(); topReq != nil {
+			return topReq.TPOT
+		}
+	}
+	return 0 // no running requests or no TPOT SLOs
+}
+
+func (s *SLOAwareRouter) getPodRunningRequestCount(pod schedulingtypes.Pod) int {
+	podName := types.NamespacedName{
+		Name:      pod.GetPod().NamespacedName.Name,
+		Namespace: pod.GetPod().NamespacedName.Namespace,
+	}
+	if runningReqs, ok := s.runningRequestLists[podName]; ok {
+		return runningReqs.GetSize()
+	}
+	return 0 // no running requests
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go
new file mode 100644
index 000000000..036b5ee76
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go
@@ -0,0 +1,53 @@
+/*
+© 2025 The Kubernetes Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+	http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package slo_aware_router contains helpers to decouple latency-predictor logic.
+package slo_aware_router
+
+import schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+
+type HeadroomStrategy string
+
+type Choice struct {
+	PodName schedulingtypes.Pod
+	Weight  int
+}
+
+const (
+	// HeadroomStrategyLeast prioritizes pods with the least positive headroom (better packing)
+	HeadroomStrategyLeast HeadroomStrategy = "least"
+	// HeadroomStrategyMost prioritizes pods with the most positive headroom (more conservative)
+	HeadroomStrategyMost HeadroomStrategy = "most"
+
+	HeadroomStrategyCompositeLeast HeadroomStrategy = "composite-least"
+	HeadroomStrategyCompositeMost  HeadroomStrategy = "composite-most"
+	HeadroomStrategyCompositeOnly  HeadroomStrategy = "composite-only"
+
+	// TTFT SLO header key
+	TTFTSLOHeaderKey = "x-slo-ttft-ms"
+	// TPOT SLO header key
+	TPOTSLOHeaderKey = "x-slo-tpot-ms"
+)
+
+const (
+	SLOAwareRouterPluginType = "slo-aware-routing"
+	eps                      = 1e-9
+	Wmax                     = 100
+	minWeight                = 1
+)
+
+type PodSelectionMode string
+
+const (
+	PodSelectionLinear PodSelectionMode = "linear" // weighted-random (current behavior)
+	PodSelectionMax    PodSelectionMode = "max"    // pick argmax weight
+)
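Every selection path above funnels its weighted Choice slice into performWeightedRandomSelection, which this diff references but defines elsewhere in the PR. As a reading aid only, a minimal sketch consistent with the Choice type and the PodSelectionMode constants in types.go might look like the following; the function name, the fallback behavior, and the use of the package-level SelectionMode are illustrative assumptions, not the PR's actual implementation.

package slo_aware_router

import (
	"math/rand"

	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// sketchWeightedSelection is a hypothetical stand-in for performWeightedRandomSelection.
// It assumes total equals the sum of all choice weights, as maintained by the callers above.
func sketchWeightedSelection(choices []Choice, total int, fallback []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod {
	if len(choices) == 0 || total <= 0 {
		// Degenerate bucket: fall back to a uniform pick over the candidates (assumed non-empty).
		if len(fallback) == 0 {
			return nil
		}
		return fallback[r.Intn(len(fallback))].Pod
	}
	if SelectionMode == PodSelectionMax {
		// "max" mode: deterministic argmax over weights.
		best := choices[0]
		for _, c := range choices[1:] {
			if c.Weight > best.Weight {
				best = c
			}
		}
		return best.PodName
	}
	// "linear" mode: draw a point in [0, total) and walk the cumulative weights,
	// so each pod is chosen with probability Weight/total.
	pick := r.Intn(total)
	for _, c := range choices {
		pick -= c.Weight
		if pick < 0 {
			return c.PodName
		}
	}
	return choices[len(choices)-1].PodName // unreachable when total is consistent
}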