diff --git a/cmd/main.go b/cmd/main.go index 92021131..23cd69b8 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -217,7 +217,9 @@ func main() { // Initialize GPU allocator and set up watches allocator, portAllocator := startTensorFusionAllocators(ctx, mgr) - startWebhook(mgr, portAllocator) + // Create pricing provider for webhook + pricingProvider := pricing.NewStaticPricingProvider() + startWebhook(mgr, portAllocator, pricingProvider) scheduler := startScheduler(ctx, allocator, mgr) @@ -441,11 +443,15 @@ func startCustomResourceController( } } -func startWebhook(mgr manager.Manager, portAllocator *portallocator.PortAllocator) { +func startWebhook( + mgr manager.Manager, + portAllocator *portallocator.PortAllocator, + pricingProvider pricing.PricingProvider, +) { if os.Getenv(constants.EnableWebhookEnv) == constants.FalseStringValue { return } - if err := webhookcorev1.SetupPodWebhookWithManager(mgr, portAllocator); err != nil { + if err := webhookcorev1.SetupPodWebhookWithManager(mgr, portAllocator, pricingProvider); err != nil { setupLog.Error(err, "unable to create webhook", "webhook", "Pod") os.Exit(1) } diff --git a/internal/cloudprovider/pricing/pricing.go b/internal/cloudprovider/pricing/pricing.go index 33ee529f..e8854583 100644 --- a/internal/cloudprovider/pricing/pricing.go +++ b/internal/cloudprovider/pricing/pricing.go @@ -31,6 +31,7 @@ import ( "github.com/NexusGPU/tensor-fusion/internal/cloudprovider/types" "github.com/NexusGPU/tensor-fusion/internal/config" "github.com/NexusGPU/tensor-fusion/internal/constants" + "k8s.io/apimachinery/pkg/api/resource" "sigs.k8s.io/controller-runtime/pkg/log" ) @@ -39,11 +40,17 @@ const ( providerAzure = "azure" ) +// CompleteGPUInfo combines GpuInfo with VRAM information from instance data +type CompleteGPUInfo struct { + *config.GpuInfo + VRAMGigabytes int32 +} + // Global data initialized at package load time var ( globalAWSGPUInstanceData map[string]GPUNodeInstanceInfoAndPrice globalAzureGPUInstanceData map[string]GPUNodeInstanceInfoAndPrice - tflopsMap map[string]*config.GpuInfo + tflopsMap map[string]*CompleteGPUInfo ) var readyCh = make(chan struct{}) @@ -51,8 +58,9 @@ var initOnce sync.Once // PricingProvider provides pricing information and calculations for instance types type PricingProvider interface { - GetPricing(instanceType, capacityType tfv1.CapacityTypeEnum) (float64, bool) - GetGPUNodeInstanceTypeInfo(region string) ([]string, bool) + GetPricing(instanceType string, capacityType tfv1.CapacityTypeEnum, region string) (float64, bool) + GetRegionalGPUNodeInstanceTypes(region string) ([]types.GPUNodeInstanceInfo, bool) + GetGPUCapacityByModel(gpuModel string) (resource.Quantity, resource.Quantity, bool) } type GPUNodeInstanceInfoAndPrice struct { @@ -77,7 +85,7 @@ var awsCSV string var azureCSV string func init() { - tflopsMap = make(map[string]*config.GpuInfo, 100) + tflopsMap = make(map[string]*CompleteGPUInfo, 100) } func SetTflopsMapAndInitGPUPricingInfo(ctx context.Context, gpuInfos *[]config.GpuInfo) { @@ -86,8 +94,11 @@ func SetTflopsMapAndInitGPUPricingInfo(ctx context.Context, gpuInfos *[]config.G return } for _, gpuInfo := range *gpuInfos { - tflopsMap[gpuInfo.FullModelName] = &gpuInfo - tflopsMap[gpuInfo.Model] = &gpuInfo + completeInfo := &CompleteGPUInfo{ + GpuInfo: &gpuInfo, + } + tflopsMap[gpuInfo.FullModelName] = completeInfo + tflopsMap[gpuInfo.Model] = completeInfo } initOnce.Do(func() { @@ -151,6 +162,11 @@ func loadCSVInstanceDataFromPath(ctx context.Context, data []byte, provider stri } instanceInfo.FP16TFlopsPerGPU = gpuInfo.Fp16TFlops.AsApproximateFloat64() + // Fill VRAM information if not already set + if gpuInfo.VRAMGigabytes == 0 { + gpuInfo.VRAMGigabytes = instanceInfo.VRAMGigabytesPerGPU + } + instanceInfoAndPrice := GPUNodeInstanceInfoAndPrice{ GPUNodeInstanceInfo: instanceInfo, onDemandPrice: prices[0], @@ -416,3 +432,19 @@ func (p *StaticPricingProvider) GetRegionalGPUNodeInstanceTypes(region string) ( return instanceTypes, len(instanceTypes) > 0 } + +// GetGPUCapacityByModel gets the full capacity (TFlops and VRAM) for a GPU model +// Returns TFlops, VRAM, and whether found +func (p *StaticPricingProvider) GetGPUCapacityByModel(gpuModel string) (resource.Quantity, resource.Quantity, bool) { + <-readyCh + + gpuInfo, exists := tflopsMap[gpuModel] + if !exists { + return resource.Quantity{}, resource.Quantity{}, false + } + + tflops := gpuInfo.Fp16TFlops + vram := *resource.NewQuantity(int64(gpuInfo.VRAMGigabytes)*constants.GiBToBytes, resource.BinarySI) + + return tflops, vram, true +} diff --git a/internal/constants/constants.go b/internal/constants/constants.go index 32b3d6bc..bf95b3d9 100644 --- a/internal/constants/constants.go +++ b/internal/constants/constants.go @@ -69,6 +69,7 @@ const ( GPUModelAnnotation = Domain + "/gpu-model" // GPU ID list is assigned by scheduler, should not specified by user GPUDeviceIDsAnnotation = Domain + "/gpu-ids" + DedicatedGPUAnnotation = Domain + "/dedicated-gpu" SetPendingOwnedWorkloadAnnotation = Domain + "/pending-owned-workload" PricingAnnotation = Domain + "/hourly-pricing" // In remote vGPU mode, selected workload is set by user with /workload annotation or generated by system diff --git a/internal/metrics/recorder.go b/internal/metrics/recorder.go index 9050df00..d01ad315 100644 --- a/internal/metrics/recorder.go +++ b/internal/metrics/recorder.go @@ -187,19 +187,37 @@ func SetPoolMetrics(poolObj *tfv1.GPUPool) { } if poolObj.Status.VirtualAvailableTFlops != nil && poolObj.Status.VirtualAvailableVRAM != nil { - poolMetricsMap[poolObj.Name].AllocatedVramPercentToVirtualCap = poolMetricsMap[poolObj.Name].AllocatedVramBytes / - poolObj.Status.VirtualVRAM.AsApproximateFloat64() * 100 + virtualVRAM := poolObj.Status.VirtualVRAM.AsApproximateFloat64() + virtualTFlops := poolObj.Status.VirtualTFlops.AsApproximateFloat64() - poolMetricsMap[poolObj.Name].AllocatedTflopsPercentToVirtualCap = poolMetricsMap[poolObj.Name].AllocatedTflops / - poolObj.Status.VirtualTFlops.AsApproximateFloat64() * 100 - poolMetricsMap[poolObj.Name].AssignedLimitedTFlops = poolObj.Status.VirtualTFlops.AsApproximateFloat64() - + if virtualVRAM > 0 { + poolMetricsMap[poolObj.Name].AllocatedVramPercentToVirtualCap = poolMetricsMap[poolObj.Name].AllocatedVramBytes / virtualVRAM * 100 + } else { + poolMetricsMap[poolObj.Name].AllocatedVramPercentToVirtualCap = 0 + } + + if virtualTFlops > 0 { + poolMetricsMap[poolObj.Name].AllocatedTflopsPercentToVirtualCap = poolMetricsMap[poolObj.Name].AllocatedTflops / virtualTFlops * 100 + } else { + poolMetricsMap[poolObj.Name].AllocatedTflopsPercentToVirtualCap = 0 + } + + poolMetricsMap[poolObj.Name].AssignedLimitedTFlops = virtualTFlops - poolObj.Status.VirtualAvailableTFlops.AsApproximateFloat64() - poolMetricsMap[poolObj.Name].AssignedLimitedVramBytes = poolObj.Status.VirtualVRAM.AsApproximateFloat64() - + poolMetricsMap[poolObj.Name].AssignedLimitedVramBytes = virtualVRAM - poolObj.Status.VirtualAvailableVRAM.AsApproximateFloat64() - poolMetricsMap[poolObj.Name].AssignedLimitedTFlopsPercentToVirtualCap = poolMetricsMap[poolObj.Name].AssignedLimitedTFlops / - poolObj.Status.VirtualTFlops.AsApproximateFloat64() * 100 - poolMetricsMap[poolObj.Name].AssignedLimitedVramPercentToVirtualCap = poolMetricsMap[poolObj.Name].AssignedLimitedVramBytes / - poolObj.Status.VirtualVRAM.AsApproximateFloat64() * 100 + + if virtualTFlops > 0 { + poolMetricsMap[poolObj.Name].AssignedLimitedTFlopsPercentToVirtualCap = poolMetricsMap[poolObj.Name].AssignedLimitedTFlops / virtualTFlops * 100 + } else { + poolMetricsMap[poolObj.Name].AssignedLimitedTFlopsPercentToVirtualCap = 0 + } + + if virtualVRAM > 0 { + poolMetricsMap[poolObj.Name].AssignedLimitedVramPercentToVirtualCap = poolMetricsMap[poolObj.Name].AssignedLimitedVramBytes / virtualVRAM * 100 + } else { + poolMetricsMap[poolObj.Name].AssignedLimitedVramPercentToVirtualCap = 0 + } } poolMetricsMap[poolObj.Name].GPUCount = int(poolObj.Status.TotalGPUs) } diff --git a/internal/webhook/v1/pod_webhook.go b/internal/webhook/v1/pod_webhook.go index 53610ffe..542a3ab0 100644 --- a/internal/webhook/v1/pod_webhook.go +++ b/internal/webhook/v1/pod_webhook.go @@ -37,6 +37,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/webhook/admission" tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/internal/cloudprovider/pricing" "github.com/NexusGPU/tensor-fusion/internal/constants" "github.com/NexusGPU/tensor-fusion/internal/portallocator" "github.com/NexusGPU/tensor-fusion/internal/utils" @@ -46,24 +47,26 @@ import ( var httpClient = &http.Client{Timeout: 10 * time.Second} // SetupPodWebhookWithManager registers the webhook for Pod in the manager. -func SetupPodWebhookWithManager(mgr ctrl.Manager, portAllocator *portallocator.PortAllocator) error { +func SetupPodWebhookWithManager(mgr ctrl.Manager, portAllocator *portallocator.PortAllocator, pricingProvider pricing.PricingProvider) error { webhookServer := mgr.GetWebhookServer() webhookServer.Register("/mutate-v1-pod", &admission.Webhook{ Handler: &TensorFusionPodMutator{ - decoder: admission.NewDecoder(runtime.NewScheme()), - Client: mgr.GetClient(), - portAllocator: portAllocator, + decoder: admission.NewDecoder(runtime.NewScheme()), + Client: mgr.GetClient(), + portAllocator: portAllocator, + pricingProvider: pricingProvider, }, }) return nil } type TensorFusionPodMutator struct { - Client client.Client - decoder admission.Decoder - portAllocator *portallocator.PortAllocator + Client client.Client + decoder admission.Decoder + portAllocator *portallocator.PortAllocator + pricingProvider pricing.PricingProvider } // Handle implements admission.Handler interface. @@ -100,7 +103,7 @@ func (m *TensorFusionPodMutator) Handle(ctx context.Context, req admission.Reque return admission.Errored(http.StatusBadRequest, fmt.Errorf("failed to marshal current pod: %w", err)) } - tfInfo, err := ParseTensorFusionInfo(ctx, m.Client, pod) + tfInfo, err := ParseTensorFusionInfo(ctx, m.Client, pod, m.pricingProvider) if err != nil { return admission.Errored(http.StatusInternalServerError, fmt.Errorf("parse tf resources: %w", err)) } diff --git a/internal/webhook/v1/pod_webhook_test.go b/internal/webhook/v1/pod_webhook_test.go index 55f29233..d72770cc 100644 --- a/internal/webhook/v1/pod_webhook_test.go +++ b/internal/webhook/v1/pod_webhook_test.go @@ -23,6 +23,7 @@ import ( "net/http" tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/internal/cloudprovider/pricing" "github.com/NexusGPU/tensor-fusion/internal/config" "github.com/NexusGPU/tensor-fusion/internal/constants" . "github.com/onsi/ginkgo/v2" @@ -532,7 +533,9 @@ var _ = Describe("TensorFusionPodMutator", func() { }, }, } - tfInfo, err := ParseTensorFusionInfo(ctx, k8sClient, pod) + // Create a mock pricing provider for testing + mockPricingProvider := &pricing.StaticPricingProvider{} + tfInfo, err := ParseTensorFusionInfo(ctx, k8sClient, pod, mockPricingProvider) Expect(err).NotTo(HaveOccurred()) Expect(tfInfo.ContainerNames).To(HaveLen(1)) Expect(tfInfo.ContainerNames[0]).To(Equal("test-container")) diff --git a/internal/webhook/v1/tf_parser.go b/internal/webhook/v1/tf_parser.go index bf805b76..cd72fbc1 100644 --- a/internal/webhook/v1/tf_parser.go +++ b/internal/webhook/v1/tf_parser.go @@ -7,6 +7,7 @@ import ( "strings" tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/internal/cloudprovider/pricing" "github.com/NexusGPU/tensor-fusion/internal/constants" "github.com/NexusGPU/tensor-fusion/internal/utils" corev1 "k8s.io/api/core/v1" @@ -29,6 +30,7 @@ func ParseTensorFusionInfo( ctx context.Context, k8sClient client.Client, pod *corev1.Pod, + pricingProvider pricing.PricingProvider, ) (utils.TensorFusionInfo, error) { var info utils.TensorFusionInfo if pod.Annotations == nil { @@ -115,6 +117,12 @@ func ParseTensorFusionInfo( workloadProfile.Spec.GPUModel = gpuModel } + // Handle dedicated GPU logic + err = handleDedicatedGPU(pod, workloadProfile, pricingProvider) + if err != nil { + return info, fmt.Errorf("handle dedicated GPU: %w", err) + } + info.Profile = &workloadProfile.Spec info.ContainerNames = containerNames return info, nil @@ -227,3 +235,30 @@ func setDefaultQuotasIfExists(workloadProfile *tfv1.WorkloadProfile, single tfv1 } } } + +// handleDedicatedGPU handles dedicated GPU annotation by setting full GPU capacity +func handleDedicatedGPU(pod *corev1.Pod, workloadProfile *tfv1.WorkloadProfile, pricingProvider pricing.PricingProvider) error { + dedicatedGPU, ok := pod.Annotations[constants.DedicatedGPUAnnotation] + if !ok || dedicatedGPU != constants.TrueStringValue { + return nil // Not a dedicated GPU request + } + + // Must have GPU model specified for dedicated GPU + if workloadProfile.Spec.GPUModel == "" { + return fmt.Errorf("dedicated GPU requires gpu-model annotation to be specified") + } + + // Get full GPU capacity from pricing provider + tflops, vram, found := pricingProvider.GetGPUCapacityByModel(workloadProfile.Spec.GPUModel) + if !found { + return fmt.Errorf("could not find capacity information for GPU model: %s", workloadProfile.Spec.GPUModel) + } + + // Set full capacity for both requests and limits + workloadProfile.Spec.Resources.Requests.Tflops = tflops + workloadProfile.Spec.Resources.Requests.Vram = vram + workloadProfile.Spec.Resources.Limits.Tflops = tflops + workloadProfile.Spec.Resources.Limits.Vram = vram + + return nil +} diff --git a/internal/webhook/v1/webhook_suite_test.go b/internal/webhook/v1/webhook_suite_test.go index 4e5d369b..26a6685d 100644 --- a/internal/webhook/v1/webhook_suite_test.go +++ b/internal/webhook/v1/webhook_suite_test.go @@ -27,6 +27,7 @@ import ( "time" tfv1 "github.com/NexusGPU/tensor-fusion/api/v1" + "github.com/NexusGPU/tensor-fusion/internal/cloudprovider/pricing" "github.com/NexusGPU/tensor-fusion/internal/config" "github.com/NexusGPU/tensor-fusion/internal/portallocator" . "github.com/onsi/ginkgo/v2" @@ -134,11 +135,13 @@ var _ = BeforeSuite(func() { }) Expect(err).NotTo(HaveOccurred()) + // Create a mock pricing provider for testing + mockPricingProvider := &pricing.StaticPricingProvider{} err = SetupPodWebhookWithManager(mgr, &portallocator.PortAllocator{ PortRangeStartCluster: 42000, PortRangeEndCluster: 62000, BitmapCluster: make([]uint64, (62000-42000)/64+1), - }) + }, mockPricingProvider) Expect(err).NotTo(HaveOccurred()) // +kubebuilder:scaffold:webhook