Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ func main() {

startCustomResourceController(ctx, mgr, metricsRecorder, allocator, portAllocator, nodeExpander)

startHttpServerForTFClient(ctx, kc, portAllocator, indexAllocator, allocator, scheduler, mgr.Elected())
startHttpServerForTFClient(ctx, kc, portAllocator, indexAllocator, allocator, scheduler, nodeExpander, mgr.Elected())

// +kubebuilder:scaffold:builder
addHealthCheckAPI(mgr)
Expand Down Expand Up @@ -306,6 +306,7 @@ func startHttpServerForTFClient(
indexAllocator *indexallocator.IndexAllocator,
allocator *gpuallocator.GpuAllocator,
scheduler *scheduler.Scheduler,
nodeExpander *expander.NodeExpander,
leaderChan <-chan struct{},
) {
client, err := client.NewWithWatch(kc, client.Options{Scheme: scheme})
Expand Down Expand Up @@ -333,8 +334,13 @@ func startHttpServerForTFClient(
setupLog.Error(err, "failed to create allocator info router")
os.Exit(1)
}
nodeScalerInfoRouter, err := router.NewNodeScalerInfoRouter(ctx, nodeExpander)
if err != nil {
setupLog.Error(err, "failed to create node scaler info router")
os.Exit(1)
}
httpServer := server.NewHTTPServer(
connectionRouter, assignHostPortRouter, assignIndexRouter, allocatorInfoRouter, leaderChan,
connectionRouter, assignHostPortRouter, assignIndexRouter, allocatorInfoRouter, nodeScalerInfoRouter, leaderChan,
)
go func() {
err := httpServer.Run()
Expand Down
2 changes: 1 addition & 1 deletion internal/controller/pod_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (r *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.R
pod := &corev1.Pod{}
if err := r.Get(ctx, req.NamespacedName, pod); err != nil {
if errors.IsNotFound(err) {
r.Expander.RemovePreSchedulePod(req.Name, true)
_ = r.Expander.RemovePreSchedulePod(req.Name, true)
r.Allocator.DeallocByPodIdentifier(ctx, req.NamespacedName)
metrics.RemoveWorkerMetrics(req.Name, time.Now())
log.Info("Released GPU resources when pod deleted", "pod", req.NamespacedName)
Expand Down
145 changes: 104 additions & 41 deletions internal/scheduler/expander/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,17 @@ const (
)

type NodeExpander struct {
client client.Client
scheduler *scheduler.Scheduler
allocator *gpuallocator.GpuAllocator
logger klog.Logger
inFlightNodes map[string][]*tfv1.GPU
preSchedulePods map[string]*tfv1.AllocRequest
preScheduleTimers map[string]*time.Timer
eventRecorder record.EventRecorder
mu sync.RWMutex
ctx context.Context
client client.Client
scheduler *scheduler.Scheduler
allocator *gpuallocator.GpuAllocator
logger klog.Logger
inFlightNodes map[string][]*tfv1.GPU
inFlightNodeClaims sync.Map
preSchedulePods map[string]*tfv1.AllocRequest
preScheduleTimers map[string]*time.Timer
eventRecorder record.EventRecorder
mu sync.RWMutex
ctx context.Context
}

func NewNodeExpander(
Expand All @@ -54,15 +55,16 @@ func NewNodeExpander(
) *NodeExpander {

expander := &NodeExpander{
client: allocator.Client,
scheduler: scheduler,
allocator: allocator,
logger: log.FromContext(ctx).WithValues("component", "NodeExpander"),
inFlightNodes: make(map[string][]*tfv1.GPU, 10),
preSchedulePods: make(map[string]*tfv1.AllocRequest, 20),
preScheduleTimers: make(map[string]*time.Timer, 20),
eventRecorder: recorder,
ctx: ctx,
client: allocator.Client,
scheduler: scheduler,
allocator: allocator,
logger: log.FromContext(ctx).WithValues("component", "NodeExpander"),
inFlightNodes: make(map[string][]*tfv1.GPU, 10),
preSchedulePods: make(map[string]*tfv1.AllocRequest, 20),
preScheduleTimers: make(map[string]*time.Timer, 20),
inFlightNodeClaims: sync.Map{},
eventRecorder: recorder,
ctx: ctx,
}
allocator.RegisterBindHandler(func(req *tfv1.AllocRequest) {
obj := &corev1.ObjectReference{
Expand All @@ -73,15 +75,68 @@ func NewNodeExpander(
UID: req.PodMeta.UID,
ResourceVersion: req.PodMeta.ResourceVersion,
}
recorder.Eventf(obj, corev1.EventTypeNormal, "NodeExpansionCheck",
"new node provisioned and pod scheduled successfully")
expander.logger.Info("new node provisioned and pod scheduled successfully",
"namespace", req.PodMeta.Namespace, "pod", req.PodMeta.Name)
expander.RemovePreSchedulePod(req.PodMeta.Name, true)

removed := expander.RemovePreSchedulePod(req.PodMeta.Name, true)
if removed {
recorder.Eventf(obj, corev1.EventTypeNormal, "NodeExpansionCheck",
"new node provisioned and pod scheduled successfully")
}
})

// Start checking inFlightNodeClaims every minute to avoid stuck in inFlightNodes
go func() {
for {
time.Sleep(time.Minute)
expander.inFlightNodeClaims.Range(func(key, _ interface{}) bool {
karpenterNodeClaim := &karpv1.NodeClaim{}
if err := expander.client.Get(expander.ctx, client.ObjectKey{Name: key.(string)}, karpenterNodeClaim); err != nil {
if errors.IsNotFound(err) {
expander.inFlightNodeClaims.Delete(key)
expander.RemoveInFlightNode(key.(string))
expander.logger.Info("karpenter node claim not found, remove from inFlightNodeClaims and inFlightNodes", "nodeClaimName", key.(string))
return true
}
expander.logger.Error(err, "failed to get karpenter node claim", "nodeClaimName", key.(string))
return true
}
if !karpenterNodeClaim.DeletionTimestamp.IsZero() {
expander.inFlightNodeClaims.Delete(key)
expander.RemoveInFlightNode(key.(string))
expander.logger.Info("karpenter node claim is deleted, remove from inFlightNodeClaims and inFlightNodes", "nodeClaimName", key.(string))
return true
}
expander.mu.RLock()
defer expander.mu.RUnlock()
if _, ok := expander.inFlightNodes[karpenterNodeClaim.Status.NodeName]; !ok {
expander.inFlightNodeClaims.Delete(key)
expander.logger.Info("karpenter node claim has been provisioned, remove from inFlightNodeClaims", "nodeClaimName", key.(string))
return true
}
return true
})
}
}()

return expander
}

// GetNodeScalerInfo returns a point-in-time snapshot of the expander's
// internal scaling state (in-flight nodes, in-flight Karpenter node claims,
// pre-scheduled pods, and the number of active pre-schedule timers) for the
// /node-scaler debug endpoint.
//
// The mutex-guarded maps are copied while the read lock is held: returning
// the live maps would let the HTTP handler serialize them after the lock is
// released, racing with concurrent writers that mutate them under e.mu.
func (e *NodeExpander) GetNodeScalerInfo() any {
	e.mu.RLock()
	defer e.mu.RUnlock()

	inFlightNodeSnapshot := make(map[string][]*tfv1.GPU, len(e.inFlightNodes))
	for node, gpus := range e.inFlightNodes {
		inFlightNodeSnapshot[node] = gpus
	}
	preSchedulePodSnapshot := make(map[string]*tfv1.AllocRequest, len(e.preSchedulePods))
	for pod, req := range e.preSchedulePods {
		preSchedulePodSnapshot[pod] = req
	}

	// sync.Map is safe to iterate without e.mu, but snapshot it as well so
	// the response is a single consistent view.
	inFlightNodeClaimSnapshot := make(map[string]any)
	e.inFlightNodeClaims.Range(func(key, value any) bool {
		inFlightNodeClaimSnapshot[key.(string)] = value
		return true
	})

	return map[string]any{
		"inFlightNodes":       inFlightNodeSnapshot,
		"inFlightNodeClaims":  inFlightNodeClaimSnapshot,
		"preSchedulePods":     preSchedulePodSnapshot,
		"preScheduleTimerNum": len(e.preScheduleTimers),
	}
}

func (e *NodeExpander) ProcessExpansion(ctx context.Context, pod *corev1.Pod) error {
if pod == nil {
return fmt.Errorf("pod cannot be nil")
Expand Down Expand Up @@ -196,11 +251,11 @@ func (e *NodeExpander) addInFlightNodeAndPreSchedulePod(allocRequest *tfv1.Alloc
err := e.client.Get(e.ctx, client.ObjectKey{Name: podMeta.Name, Namespace: podMeta.Namespace}, currentPod)
if err != nil {
if errors.IsNotFound(err) || !currentPod.DeletionTimestamp.IsZero() {
e.RemovePreSchedulePod(podMeta.Name, false)
_ = e.RemovePreSchedulePod(podMeta.Name, false)
}
e.logger.Error(err, "failed to get pod for node expansion check",
"namespace", podMeta.Namespace, "pod", podMeta.Name)
e.RemovePreSchedulePod(podMeta.Name, false)
_ = e.RemovePreSchedulePod(podMeta.Name, false)
return
}
if currentPod.Spec.NodeName != "" {
Expand All @@ -209,14 +264,14 @@ func (e *NodeExpander) addInFlightNodeAndPreSchedulePod(allocRequest *tfv1.Alloc
"new node provisioned and pod scheduled successfully")
e.logger.Info("new node provisioned and pod scheduled successfully",
"namespace", podMeta.Namespace, "pod", podMeta.Name)
e.RemovePreSchedulePod(podMeta.Name, false)
_ = e.RemovePreSchedulePod(podMeta.Name, false)
} else {
// not scheduled, record warning event and remove pre-scheduled pod
e.eventRecorder.Eventf(currentPod, corev1.EventTypeWarning, "NodeExpansionCheck",
"failed to schedule pod after 10 minutes")
e.logger.Info("failed to schedule pod after 10 minutes",
"namespace", podMeta.Namespace, "pod", podMeta.Name)
e.RemovePreSchedulePod(podMeta.Name, false)
_ = e.RemovePreSchedulePod(podMeta.Name, false)
}
})
e.preScheduleTimers[podMeta.Name] = timer
Expand All @@ -228,25 +283,32 @@ func (e *NodeExpander) RemoveInFlightNode(nodeName string) {
return
}
e.mu.Lock()
delete(e.inFlightNodes, nodeName)
e.logger.Info("Removed in-flight node", "node", nodeName, "remaining inflight nodes", len(e.inFlightNodes))
if _, ok := e.inFlightNodes[nodeName]; ok {
delete(e.inFlightNodes, nodeName)
e.logger.Info("Removed in-flight node", "node", nodeName, "remaining inflight nodes", len(e.inFlightNodes))
}
e.mu.Unlock()
}

func (e *NodeExpander) RemovePreSchedulePod(podName string, stopTimer bool) {
func (e *NodeExpander) RemovePreSchedulePod(podName string, stopTimer bool) bool {
if e == nil {
return
return false
}
e.mu.Lock()
defer e.mu.Unlock()
if stopTimer {
if timer, ok := e.preScheduleTimers[podName]; ok {
timer.Stop()
}
}
delete(e.preScheduleTimers, podName)
delete(e.preSchedulePods, podName)
e.logger.Info("Removed pre-scheduled pod", "pod", podName, "remaining pre-scheduled pods", len(e.preSchedulePods))
e.mu.Unlock()

if _, ok := e.preSchedulePods[podName]; ok {
delete(e.preSchedulePods, podName)
e.logger.Info("Removed pre-scheduled pod", "pod", podName, "remaining pre-scheduled pods", len(e.preSchedulePods))
return true
}
return false
}

func (e *NodeExpander) prepareNewNodesForScheduleAttempt(
Expand Down Expand Up @@ -327,7 +389,7 @@ func (e *NodeExpander) checkGPUFitWithInflightNodes(pod *corev1.Pod, gpus []*tfv
if !preScheduledPodPreAllocated {
e.logger.Info("[Warning] pre-scheduled pod can not set into InFlight node anymore, remove queue and retry later",
"pod", alloc.PodMeta.Name, "namespace", alloc.PodMeta.Namespace)
e.RemovePreSchedulePod(alloc.PodMeta.Name, true)
_ = e.RemovePreSchedulePod(alloc.PodMeta.Name, true)
}
}

Expand Down Expand Up @@ -409,13 +471,13 @@ func (e *NodeExpander) createGPUNodeClaim(ctx context.Context, pod *corev1.Pod,
e.logger.Info("node is not owned by any known provisioner, skip expansion", "node", preparedNode.Name)
return fmt.Errorf("node is not owned by any known provisioner, skip expansion")
}
e.logger.Info("start expanding node from existing template node", "tmplNode", preparedNode.Name)
e.logger.Info("start expanding node from existing template node", "newNodeClaimName", preparedNode.Name)
if isKarpenterNodeClaim {
// Check if controllerMeta's parent is GPUNodeClaim using unstructured object
return e.handleKarpenterNodeClaim(ctx, pod, preparedNode, controlledBy)
} else if isGPUNodeClaim {
// Running in Provisioning mode, clone the parent GPUNodeClaim and apply
e.logger.Info("node is controlled by GPUNodeClaim, cloning another to expand node", "tmplNode", preparedNode.Name)
e.logger.Info("node is controlled by GPUNodeClaim, cloning another to expand node", "newNode", preparedNode.Name)
return e.cloneGPUNodeClaim(ctx, pod, preparedNode, controlledBy)
}
return nil
Expand Down Expand Up @@ -450,12 +512,12 @@ func (e *NodeExpander) handleKarpenterNodeClaim(ctx context.Context, pod *corev1
if nodeClaimParent != nil && nodeClaimParent.Kind == tfv1.GPUNodeClaimKind {
// Parent is GPUNodeClaim, clone it and let cloudprovider module create real GPUNode
e.logger.Info("NodeClaim parent is GPUNodeClaim, cloning another to expand node",
"nodeClaimName", controlledBy.Name, "gpuNodeClaimParent", nodeClaimParent.Name)
"controlledBy", controlledBy.Name, "gpuNodeClaimParent", nodeClaimParent.Name)
return e.cloneGPUNodeClaim(ctx, pod, preparedNode, nodeClaimParent)
} else if hasNodePoolParent {
// owned by Karpenter node pool, create NodeClaim directly with special label identifier
e.logger.Info("NodeClaim owned by Karpenter Pool, creating Karpenter NodeClaim to expand node",
"nodeClaimName", controlledBy.Name)
"controlledBy", controlledBy.Name)
return e.createKarpenterNodeClaimDirect(ctx, pod, preparedNode, nodeClaim)
} else {
return fmt.Errorf("NodeClaim has no valid parent, can not expand node, should not happen")
Expand Down Expand Up @@ -527,9 +589,10 @@ func (e *NodeExpander) createKarpenterNodeClaimDirect(ctx context.Context, pod *
e.eventRecorder.Eventf(pod, corev1.EventTypeWarning, "NodeExpansionFailed", "failed to create new NodeClaim: %v", err)
return fmt.Errorf("failed to create NodeClaim: %w", err)
}

e.inFlightNodeClaims.Store(newNodeClaim.Name, true)
e.eventRecorder.Eventf(pod, corev1.EventTypeNormal, "NodeExpansionCompleted", "created new NodeClaim for node expansion: %s", newNodeClaim.Name)
e.logger.Info("created new NodeClaim for node expansion", "pod", pod.Name, "namespace", pod.Namespace, "nodeClaim", newNodeClaim.Name)

return nil
}

Expand Down
2 changes: 1 addition & 1 deletion internal/scheduler/expander/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ func testPreScheduledPodManagement(suite *NodeExpanderTestSuite) {
Expect(exists).To(BeTrue())

// Test removing pre-scheduled pod
suite.nodeExpander.RemovePreSchedulePod("test-pod", true)
_ = suite.nodeExpander.RemovePreSchedulePod("test-pod", true)

// Verify pre-scheduled pod is removed
suite.nodeExpander.mu.RLock()
Expand Down
26 changes: 26 additions & 0 deletions internal/server/router/node_scaler_info.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package router

import (
	"context"
	"net/http"

	"github.com/NexusGPU/tensor-fusion/internal/scheduler/expander"
	"github.com/gin-gonic/gin"
)

// NodeScalerInfoRouter exposes the NodeExpander's internal scaling state
// over HTTP for debugging and observability.
type NodeScalerInfoRouter struct {
	nodeExpander *expander.NodeExpander
}

// NewNodeScalerInfoRouter constructs the node-scaler info router.
// ctx is accepted for signature parity with the sibling router
// constructors and is currently unused; the returned error is always nil.
func NewNodeScalerInfoRouter(
	ctx context.Context,
	nodeExpander *expander.NodeExpander,
) (*NodeScalerInfoRouter, error) {
	router := &NodeScalerInfoRouter{nodeExpander: nodeExpander}
	return router, nil
}

// Get handles GET requests by wrapping the expander's state snapshot in a
// {"data": ...} envelope and writing it as JSON with status 200.
func (r *NodeScalerInfoRouter) Get(ctx *gin.Context) {
	payload := gin.H{"data": r.nodeExpander.GetNodeScalerInfo()}
	ctx.JSON(http.StatusOK, payload)
}
2 changes: 2 additions & 0 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ func NewHTTPServer(
ahp *router.AssignHostPortRouter,
ai *router.AssignIndexRouter,
alc *router.AllocatorInfoRouter,
nsi *router.NodeScalerInfoRouter,
leaderChan <-chan struct{},
) *gin.Engine {

Expand Down Expand Up @@ -59,5 +60,6 @@ func NewHTTPServer(
apiGroup.GET("/config", func(ctx *gin.Context) {
ctx.JSON(http.StatusOK, gin.H{"config": config.GetGlobalConfig()})
})
apiGroup.GET("/node-scaler", nsi.Get)
return r
}