Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

After calling UpdateTraffic(), make azd wait until the expected state is visible on the online endpoint: the deployment that was updated to 100% traffic should be reported with 100% traffic #3881

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 33 additions & 27 deletions cli/azd/pkg/project/ai_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/azure/azure-dev/cli/azd/pkg/osutil"
"github.com/azure/azure-dev/cli/azd/pkg/tools"
"github.com/benbjohnson/clock"
"github.com/sethvargo/go-retry"
)

const (
Expand Down Expand Up @@ -72,8 +73,8 @@ type AiHelper interface {
endpointName string,
config *ai.EndpointDeploymentConfig,
) (*armmachinelearning.OnlineDeployment, error)
// DeletePreviousDeployments deletes all previous deployments of an online endpoint except the one with 100% traffic
DeletePreviousDeployments(ctx context.Context, scope *ai.Scope, endpointName string) error
// DeleteDeployments deletes all deployments of an online endpoint except the ones in filter
DeleteDeployments(ctx context.Context, scope *ai.Scope, endpointName string, filter []string) error
// UpdateTraffic updates the traffic distribution of an online endpoint for the specified deployment
UpdateTraffic(
ctx context.Context,
Expand Down Expand Up @@ -470,33 +471,13 @@ func (a *aiHelper) DeployToEndpoint(
return onlineDeployment, nil
}

// DeletePreviousDeployments deletes all previous deployments of an online endpoint except the one with 100% traffic
func (a *aiHelper) DeletePreviousDeployments(
// DeleteDeployments deletes all deployments of an online endpoint except the ones in filter
func (a *aiHelper) DeleteDeployments(
ctx context.Context,
scope *ai.Scope,
endpointName string,
filter []string,
) error {
// Get the endpoint
getEndpointResponse, err := a.endpointsClient.Get(ctx, scope.ResourceGroup(), scope.Workspace(), endpointName, nil)
if err != nil {
return err
}

onlineEndpoint := getEndpointResponse.OnlineEndpoint
var deploymentName string

// Detect the deployment with 100% traffic
for key, trafficWeight := range onlineEndpoint.Properties.Traffic {
if *trafficWeight == 100 {
deploymentName = key
break
}
}

if deploymentName == "" {
return errors.New("no deployment found with 100% traffic")
}

// Get existing deployments
existingDeployments := []*armmachinelearning.OnlineDeployment{}

Expand All @@ -512,8 +493,8 @@ func (a *aiHelper) DeletePreviousDeployments(

// Delete previous deployments
for _, existingDeployment := range existingDeployments {
// Ignore the new deployment
if *existingDeployment.Name == deploymentName {
// Ignore the ones from the filter list
if slices.Contains(filter, *existingDeployment.Name) {
continue
}

Expand Down Expand Up @@ -571,6 +552,31 @@ func (a *aiHelper) UpdateTraffic(
return nil, err
}

// Before moving on, verify that the online endpoint reports the
vhvb1989 marked this conversation as resolved.
Show resolved Hide resolved
// expected traffic (100%) for the updated deployment.
err = retry.Do(ctx, retry.WithMaxRetries(3, retry.NewConstant(10*time.Second)),
func(ctx context.Context) error {
getEndpointResponse, err = a.endpointsClient.Get(
ctx, scope.ResourceGroup(), scope.Workspace(), endpointName, nil)

if err != nil {
return retry.RetryableError(err)
}
if getEndpointResponse.OnlineEndpoint.Properties == nil {
return retry.RetryableError(errors.New("online endpoint properties are nil"))
}
// check 100% traffic
for key, trafficWeight := range getEndpointResponse.OnlineEndpoint.Properties.Traffic {
if key == deploymentName && *trafficWeight == 100 {
return nil
}
}
return retry.RetryableError(errors.New("online endpoint traffic is not 100% yet"))
})
if err != nil {
return nil, err
}

return &updateResponse.OnlineEndpoint, nil
}

Expand Down
12 changes: 10 additions & 2 deletions cli/azd/pkg/project/ai_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,17 @@ func Test_AiHelper_UpdateTraffic(t *testing.T) {
trafficMap := map[string]*int32{
deploymentName: convert.RefOf(int32(100)),
}
endpoint := &armmachinelearning.OnlineEndpoint{
Name: convert.RefOf(endpointName),
Properties: &armmachinelearning.OnlineEndpointProperties{
Traffic: map[string]*int32{
deploymentName: convert.RefOf(int32(100)),
},
},
}

mockPythonBridge := &mockPythonBridge{}
mockai.RegisterGetOnlineEndpoint(mockContext, scope.Workspace(), endpointName, http.StatusOK, nil)
mockai.RegisterGetOnlineEndpoint(mockContext, scope.Workspace(), endpointName, http.StatusOK, endpoint)
updateRequest := mockai.RegisterUpdateOnlineEndpoint(mockContext, scope.Workspace(), endpointName, trafficMap)

aiHelper := newAiHelper(t, mockContext, env, mockPythonBridge)
Expand Down Expand Up @@ -386,7 +394,7 @@ func Test_AiHelper_DeletePreviousDeployments(t *testing.T) {
}

aiHelper := newAiHelper(t, mockContext, env, mockPythonBridge)
err := aiHelper.DeletePreviousDeployments(*mockContext.Context, scope, endpointName)
err := aiHelper.DeleteDeployments(*mockContext.Context, scope, endpointName, []string{"MY-DEPLOYMENT"})
require.Len(t, deleteRequests, len(existingDeploymentNames))

require.NoError(t, err)
Expand Down
15 changes: 13 additions & 2 deletions cli/azd/pkg/project/service_target_ai_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,26 @@ func (m *aiEndpointTarget) Deploy(
return
}

if onlineDeployment == nil {
task.SetError(fmt.Errorf("unexpected response from deployToEndpoint: deployment is nil"))
return
}
if onlineDeployment.Name == nil {
task.SetError(fmt.Errorf("unexpected response from deployToEndpoint: deployment name is nil"))
return
}

deploymentName := *onlineDeployment.Name
task.SetProgress(NewServiceProgress("Updating traffic"))
_, err = m.aiHelper.UpdateTraffic(ctx, workspaceScope, endpointName, *onlineDeployment.Name)
_, err = m.aiHelper.UpdateTraffic(ctx, workspaceScope, endpointName, deploymentName)
if err != nil {
task.SetError(fmt.Errorf("failed updating traffic: %w", err))
return
}

task.SetProgress(NewServiceProgress("Removing old deployments"))
if err := m.aiHelper.DeletePreviousDeployments(ctx, workspaceScope, endpointName); err != nil {
if err := m.aiHelper.DeleteDeployments(
ctx, workspaceScope, endpointName, []string{deploymentName}); err != nil {
task.SetError(fmt.Errorf("failed deleting previous deployments: %w", err))
return
}
Expand Down
4 changes: 2 additions & 2 deletions cli/azd/pkg/project/service_target_ai_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func Test_MlEndpointTarget_Deploy(t *testing.T) {
On("UpdateTraffic", *mockContext.Context, scopeType, endpointName, expectedDeploymentName).
Return(onlineEndpoint, nil)
aiHelper.
On("DeletePreviousDeployments", *mockContext.Context, scopeType, endpointName).
On("DeleteDeployments", *mockContext.Context, scopeType, endpointName).
Return(nil)
aiHelper.
On("GetEndpoint", *mockContext.Context, scopeType, endpointName).
Expand Down Expand Up @@ -221,7 +221,7 @@ func (m *mockAiHelper) CreateFlow(
return args.Get(0).(*ai.Flow), args.Error(1)
}

func (m *mockAiHelper) DeletePreviousDeployments(ctx context.Context, scope *ai.Scope, endpointName string) error {
// DeleteDeployments is the mock counterpart of the AI helper's DeleteDeployments.
// It records the call on the mock and returns the error configured for it.
//
// NOTE: only ctx, scope and endpointName are forwarded to the mock; filter is
// not recorded, so expectations are matched on three arguments.
func (m *mockAiHelper) DeleteDeployments(
	ctx context.Context,
	scope *ai.Scope,
	endpointName string,
	filter []string,
) error {
	return m.Called(ctx, scope, endpointName).Error(0)
}
Expand Down
Loading