12 changes: 9 additions & 3 deletions cmd/main.go
@@ -217,7 +217,9 @@ func main() {
// Initialize GPU allocator and set up watches
allocator, portAllocator := startTensorFusionAllocators(ctx, mgr)

startWebhook(mgr, portAllocator)
// Create pricing provider for webhook
pricingProvider := pricing.NewStaticPricingProvider()
startWebhook(mgr, portAllocator, pricingProvider)

scheduler := startScheduler(ctx, allocator, mgr)

@@ -441,11 +443,15 @@ func startCustomResourceController(
}
}

func startWebhook(mgr manager.Manager, portAllocator *portallocator.PortAllocator) {
func startWebhook(
mgr manager.Manager,
portAllocator *portallocator.PortAllocator,
pricingProvider pricing.PricingProvider,
) {
if os.Getenv(constants.EnableWebhookEnv) == constants.FalseStringValue {
return
}
if err := webhookcorev1.SetupPodWebhookWithManager(mgr, portAllocator); err != nil {
if err := webhookcorev1.SetupPodWebhookWithManager(mgr, portAllocator, pricingProvider); err != nil {
setupLog.Error(err, "unable to create webhook", "webhook", "Pod")
os.Exit(1)
}
44 changes: 38 additions & 6 deletions internal/cloudprovider/pricing/pricing.go
@@ -31,6 +31,7 @@ import (
"github.com/NexusGPU/tensor-fusion/internal/cloudprovider/types"
"github.com/NexusGPU/tensor-fusion/internal/config"
"github.com/NexusGPU/tensor-fusion/internal/constants"
"k8s.io/apimachinery/pkg/api/resource"
"sigs.k8s.io/controller-runtime/pkg/log"
)

@@ -39,20 +40,27 @@ const (
providerAzure = "azure"
)

// CompleteGPUInfo combines GpuInfo with VRAM information from instance data
type CompleteGPUInfo struct {
*config.GpuInfo
VRAMGigabytes int32
}

// Global data initialized at package load time
var (
globalAWSGPUInstanceData map[string]GPUNodeInstanceInfoAndPrice
globalAzureGPUInstanceData map[string]GPUNodeInstanceInfoAndPrice
tflopsMap map[string]*config.GpuInfo
tflopsMap map[string]*CompleteGPUInfo
)

var readyCh = make(chan struct{})
var initOnce sync.Once

// PricingProvider provides pricing information and calculations for instance types
type PricingProvider interface {
GetPricing(instanceType, capacityType tfv1.CapacityTypeEnum) (float64, bool)
GetGPUNodeInstanceTypeInfo(region string) ([]string, bool)
GetPricing(instanceType string, capacityType tfv1.CapacityTypeEnum, region string) (float64, bool)
GetRegionalGPUNodeInstanceTypes(region string) ([]types.GPUNodeInstanceInfo, bool)
GetGPUCapacityByModel(gpuModel string) (resource.Quantity, resource.Quantity, bool)
}

type GPUNodeInstanceInfoAndPrice struct {
@@ -77,7 +85,7 @@ var awsCSV string
var azureCSV string

func init() {
tflopsMap = make(map[string]*config.GpuInfo, 100)
tflopsMap = make(map[string]*CompleteGPUInfo, 100)
}

func SetTflopsMapAndInitGPUPricingInfo(ctx context.Context, gpuInfos *[]config.GpuInfo) {
@@ -86,8 +94,11 @@ func SetTflopsMapAndInitGPUPricingInfo(ctx context.Context, gpuInfos *[]config.G
return
}
for _, gpuInfo := range *gpuInfos {
tflopsMap[gpuInfo.FullModelName] = &gpuInfo
tflopsMap[gpuInfo.Model] = &gpuInfo
completeInfo := &CompleteGPUInfo{
GpuInfo: &gpuInfo,
}
tflopsMap[gpuInfo.FullModelName] = completeInfo
tflopsMap[gpuInfo.Model] = completeInfo
}

initOnce.Do(func() {
@@ -151,6 +162,11 @@ func loadCSVInstanceDataFromPath(ctx context.Context, data []byte, provider stri
}
instanceInfo.FP16TFlopsPerGPU = gpuInfo.Fp16TFlops.AsApproximateFloat64()

// Fill VRAM information if not already set
if gpuInfo.VRAMGigabytes == 0 {
gpuInfo.VRAMGigabytes = instanceInfo.VRAMGigabytesPerGPU
}

instanceInfoAndPrice := GPUNodeInstanceInfoAndPrice{
GPUNodeInstanceInfo: instanceInfo,
onDemandPrice: prices[0],
@@ -416,3 +432,19 @@ func (p *StaticPricingProvider) GetRegionalGPUNodeInstanceTypes(region string) (

return instanceTypes, len(instanceTypes) > 0
}

// GetGPUCapacityByModel gets the full capacity (TFlops and VRAM) for a GPU model
// Returns TFlops, VRAM, and whether found
func (p *StaticPricingProvider) GetGPUCapacityByModel(gpuModel string) (resource.Quantity, resource.Quantity, bool) {
<-readyCh

gpuInfo, exists := tflopsMap[gpuModel]
if !exists {
return resource.Quantity{}, resource.Quantity{}, false
}

tflops := gpuInfo.Fp16TFlops
vram := *resource.NewQuantity(int64(gpuInfo.VRAMGigabytes)*constants.GiBToBytes, resource.BinarySI)

return tflops, vram, true
}
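
A caller-side sketch of the new GetGPUCapacityByModel lookup (not part of this PR; the model name is illustrative, and it assumes SetTflopsMapAndInitGPUPricingInfo has already run, since the method blocks on readyCh until then):

func fullCardCapacity() error {
	// Sketch only: NewStaticPricingProvider and GetGPUCapacityByModel come from the
	// pricing package touched by this PR; "A100" is an illustrative model key.
	provider := pricing.NewStaticPricingProvider()
	tflops, vram, ok := provider.GetGPUCapacityByModel("A100")
	if !ok {
		return fmt.Errorf("no capacity data for GPU model %q", "A100")
	}
	fmt.Printf("full-card capacity: %s TFLOPS, %s VRAM\n", tflops.String(), vram.String())
	return nil
}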
1 change: 1 addition & 0 deletions internal/constants/constants.go
@@ -69,6 +69,7 @@ const (
GPUModelAnnotation = Domain + "/gpu-model"
// GPU ID list is assigned by scheduler, should not be specified by user
GPUDeviceIDsAnnotation = Domain + "/gpu-ids"
DedicatedGPUAnnotation = Domain + "/dedicated-gpu"
SetPendingOwnedWorkloadAnnotation = Domain + "/pending-owned-workload"
PricingAnnotation = Domain + "/hourly-pricing"
// In remote vGPU mode, selected workload is set by user with /workload annotation or generated by system
38 changes: 28 additions & 10 deletions internal/metrics/recorder.go
@@ -187,19 +187,37 @@ func SetPoolMetrics(poolObj *tfv1.GPUPool) {
}

if poolObj.Status.VirtualAvailableTFlops != nil && poolObj.Status.VirtualAvailableVRAM != nil {
poolMetricsMap[poolObj.Name].AllocatedVramPercentToVirtualCap = poolMetricsMap[poolObj.Name].AllocatedVramBytes /
poolObj.Status.VirtualVRAM.AsApproximateFloat64() * 100
virtualVRAM := poolObj.Status.VirtualVRAM.AsApproximateFloat64()
virtualTFlops := poolObj.Status.VirtualTFlops.AsApproximateFloat64()

poolMetricsMap[poolObj.Name].AllocatedTflopsPercentToVirtualCap = poolMetricsMap[poolObj.Name].AllocatedTflops /
poolObj.Status.VirtualTFlops.AsApproximateFloat64() * 100
poolMetricsMap[poolObj.Name].AssignedLimitedTFlops = poolObj.Status.VirtualTFlops.AsApproximateFloat64() -
if virtualVRAM > 0 {
poolMetricsMap[poolObj.Name].AllocatedVramPercentToVirtualCap = poolMetricsMap[poolObj.Name].AllocatedVramBytes / virtualVRAM * 100
} else {
poolMetricsMap[poolObj.Name].AllocatedVramPercentToVirtualCap = 0
}

if virtualTFlops > 0 {
poolMetricsMap[poolObj.Name].AllocatedTflopsPercentToVirtualCap = poolMetricsMap[poolObj.Name].AllocatedTflops / virtualTFlops * 100
} else {
poolMetricsMap[poolObj.Name].AllocatedTflopsPercentToVirtualCap = 0
}

poolMetricsMap[poolObj.Name].AssignedLimitedTFlops = virtualTFlops -
poolObj.Status.VirtualAvailableTFlops.AsApproximateFloat64()
poolMetricsMap[poolObj.Name].AssignedLimitedVramBytes = poolObj.Status.VirtualVRAM.AsApproximateFloat64() -
poolMetricsMap[poolObj.Name].AssignedLimitedVramBytes = virtualVRAM -
poolObj.Status.VirtualAvailableVRAM.AsApproximateFloat64()
poolMetricsMap[poolObj.Name].AssignedLimitedTFlopsPercentToVirtualCap = poolMetricsMap[poolObj.Name].AssignedLimitedTFlops /
poolObj.Status.VirtualTFlops.AsApproximateFloat64() * 100
poolMetricsMap[poolObj.Name].AssignedLimitedVramPercentToVirtualCap = poolMetricsMap[poolObj.Name].AssignedLimitedVramBytes /
poolObj.Status.VirtualVRAM.AsApproximateFloat64() * 100

if virtualTFlops > 0 {
poolMetricsMap[poolObj.Name].AssignedLimitedTFlopsPercentToVirtualCap = poolMetricsMap[poolObj.Name].AssignedLimitedTFlops / virtualTFlops * 100
} else {
poolMetricsMap[poolObj.Name].AssignedLimitedTFlopsPercentToVirtualCap = 0
}

if virtualVRAM > 0 {
poolMetricsMap[poolObj.Name].AssignedLimitedVramPercentToVirtualCap = poolMetricsMap[poolObj.Name].AssignedLimitedVramBytes / virtualVRAM * 100
} else {
poolMetricsMap[poolObj.Name].AssignedLimitedVramPercentToVirtualCap = 0
}
}
poolMetricsMap[poolObj.Name].GPUCount = int(poolObj.Status.TotalGPUs)
}
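
The same zero-capacity guard now appears four times; it could be factored into a small helper with identical behavior. A sketch (safePercent is not part of this PR):

func safePercent(used, capacity float64) float64 {
	// Mirrors the guards above: return 0 instead of NaN/Inf when virtual capacity is zero.
	if capacity <= 0 {
		return 0
	}
	return used / capacity * 100
}

// e.g. poolMetricsMap[poolObj.Name].AllocatedVramPercentToVirtualCap =
//	safePercent(poolMetricsMap[poolObj.Name].AllocatedVramBytes, virtualVRAM)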
19 changes: 11 additions & 8 deletions internal/webhook/v1/pod_webhook.go
@@ -37,6 +37,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"

tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
"github.com/NexusGPU/tensor-fusion/internal/cloudprovider/pricing"
"github.com/NexusGPU/tensor-fusion/internal/constants"
"github.com/NexusGPU/tensor-fusion/internal/portallocator"
"github.com/NexusGPU/tensor-fusion/internal/utils"
@@ -46,24 +47,26 @@ import (
var httpClient = &http.Client{Timeout: 10 * time.Second}

// SetupPodWebhookWithManager registers the webhook for Pod in the manager.
func SetupPodWebhookWithManager(mgr ctrl.Manager, portAllocator *portallocator.PortAllocator) error {
func SetupPodWebhookWithManager(mgr ctrl.Manager, portAllocator *portallocator.PortAllocator, pricingProvider pricing.PricingProvider) error {
webhookServer := mgr.GetWebhookServer()

webhookServer.Register("/mutate-v1-pod",
&admission.Webhook{
Handler: &TensorFusionPodMutator{
decoder: admission.NewDecoder(runtime.NewScheme()),
Client: mgr.GetClient(),
portAllocator: portAllocator,
decoder: admission.NewDecoder(runtime.NewScheme()),
Client: mgr.GetClient(),
portAllocator: portAllocator,
pricingProvider: pricingProvider,
},
})
return nil
}

type TensorFusionPodMutator struct {
Client client.Client
decoder admission.Decoder
portAllocator *portallocator.PortAllocator
Client client.Client
decoder admission.Decoder
portAllocator *portallocator.PortAllocator
pricingProvider pricing.PricingProvider
}

// Handle implements admission.Handler interface.
@@ -100,7 +103,7 @@ func (m *TensorFusionPodMutator) Handle(ctx context.Context, req admission.Reque
return admission.Errored(http.StatusBadRequest, fmt.Errorf("failed to marshal current pod: %w", err))
}

tfInfo, err := ParseTensorFusionInfo(ctx, m.Client, pod)
tfInfo, err := ParseTensorFusionInfo(ctx, m.Client, pod, m.pricingProvider)
if err != nil {
return admission.Errored(http.StatusInternalServerError, fmt.Errorf("parse tf resources: %w", err))
}
5 changes: 4 additions & 1 deletion internal/webhook/v1/pod_webhook_test.go
@@ -23,6 +23,7 @@ import (
"net/http"

tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
"github.com/NexusGPU/tensor-fusion/internal/cloudprovider/pricing"
"github.com/NexusGPU/tensor-fusion/internal/config"
"github.com/NexusGPU/tensor-fusion/internal/constants"
. "github.com/onsi/ginkgo/v2"
@@ -532,7 +533,9 @@ var _ = Describe("TensorFusionPodMutator", func() {
},
},
}
tfInfo, err := ParseTensorFusionInfo(ctx, k8sClient, pod)
// Create a mock pricing provider for testing
mockPricingProvider := &pricing.StaticPricingProvider{}
tfInfo, err := ParseTensorFusionInfo(ctx, k8sClient, pod, mockPricingProvider)
Expect(err).NotTo(HaveOccurred())
Expect(tfInfo.ContainerNames).To(HaveLen(1))
Expect(tfInfo.ContainerNames[0]).To(Equal("test-container"))
35 changes: 35 additions & 0 deletions internal/webhook/v1/tf_parser.go
@@ -7,6 +7,7 @@ import (
"strings"

tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
"github.com/NexusGPU/tensor-fusion/internal/cloudprovider/pricing"
"github.com/NexusGPU/tensor-fusion/internal/constants"
"github.com/NexusGPU/tensor-fusion/internal/utils"
corev1 "k8s.io/api/core/v1"
@@ -29,6 +30,7 @@ func ParseTensorFusionInfo(
ctx context.Context,
k8sClient client.Client,
pod *corev1.Pod,
pricingProvider pricing.PricingProvider,
) (utils.TensorFusionInfo, error) {
var info utils.TensorFusionInfo
if pod.Annotations == nil {
@@ -115,6 +117,12 @@ func ParseTensorFusionInfo(
workloadProfile.Spec.GPUModel = gpuModel
}

// Handle dedicated GPU logic
err = handleDedicatedGPU(pod, workloadProfile, pricingProvider)
if err != nil {
return info, fmt.Errorf("handle dedicated GPU: %w", err)
}

info.Profile = &workloadProfile.Spec
info.ContainerNames = containerNames
return info, nil
@@ -227,3 +235,30 @@ func setDefaultQuotasIfExists(workloadProfile *tfv1.WorkloadProfile, single tfv1
}
}
}

// handleDedicatedGPU handles dedicated GPU annotation by setting full GPU capacity
func handleDedicatedGPU(pod *corev1.Pod, workloadProfile *tfv1.WorkloadProfile, pricingProvider pricing.PricingProvider) error {
dedicatedGPU, ok := pod.Annotations[constants.DedicatedGPUAnnotation]
if !ok || dedicatedGPU != constants.TrueStringValue {
return nil // Not a dedicated GPU request
}

// Must have GPU model specified for dedicated GPU
if workloadProfile.Spec.GPUModel == "" {
return fmt.Errorf("dedicated GPU requires gpu-model annotation to be specified")
}

// Get full GPU capacity from pricing provider
tflops, vram, found := pricingProvider.GetGPUCapacityByModel(workloadProfile.Spec.GPUModel)
if !found {
return fmt.Errorf("could not find capacity information for GPU model: %s", workloadProfile.Spec.GPUModel)
}

// Set full capacity for both requests and limits
workloadProfile.Spec.Resources.Requests.Tflops = tflops
workloadProfile.Spec.Resources.Requests.Vram = vram
workloadProfile.Spec.Resources.Limits.Tflops = tflops
workloadProfile.Spec.Resources.Limits.Vram = vram

return nil
}
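
A test-style sketch of the dedicated-GPU path (not taken from this PR's tests; the model name is illustrative, imports are elided, and it assumes pricing data has already been initialized so the capacity lookup does not block):

func exampleDedicatedGPU() error {
	pod := &corev1.Pod{
		ObjectMeta: metav1.ObjectMeta{
			Annotations: map[string]string{
				constants.DedicatedGPUAnnotation: constants.TrueStringValue,
				constants.GPUModelAnnotation:     "A100", // illustrative model name
			},
		},
	}
	// ParseTensorFusionInfo copies the gpu-model annotation into the profile before
	// handleDedicatedGPU runs; done by hand here to keep the sketch small.
	workloadProfile := &tfv1.WorkloadProfile{}
	workloadProfile.Spec.GPUModel = pod.Annotations[constants.GPUModelAnnotation]

	if err := handleDedicatedGPU(pod, workloadProfile, &pricing.StaticPricingProvider{}); err != nil {
		return err // missing gpu-model annotation or no pricing entry for the model
	}
	// Requests and limits now both carry the full-card TFLOPS and VRAM.
	return nil
}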
5 changes: 4 additions & 1 deletion internal/webhook/v1/webhook_suite_test.go
@@ -27,6 +27,7 @@ import (
"time"

tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
"github.com/NexusGPU/tensor-fusion/internal/cloudprovider/pricing"
"github.com/NexusGPU/tensor-fusion/internal/config"
"github.com/NexusGPU/tensor-fusion/internal/portallocator"
. "github.com/onsi/ginkgo/v2"
@@ -134,11 +135,13 @@ var _ = BeforeSuite(func() {
})
Expect(err).NotTo(HaveOccurred())

// Create a mock pricing provider for testing
mockPricingProvider := &pricing.StaticPricingProvider{}
err = SetupPodWebhookWithManager(mgr, &portallocator.PortAllocator{
PortRangeStartCluster: 42000,
PortRangeEndCluster: 62000,
BitmapCluster: make([]uint64, (62000-42000)/64+1),
})
}, mockPricingProvider)
Expect(err).NotTo(HaveOccurred())

// +kubebuilder:scaffold:webhook