Skip to content

Commit

Permalink
accept the correct input for executor (#1506)
Browse files Browse the repository at this point in the history
* accept the correct input for  executor

* fix lint

* add tests

* add tests

* add tests

* fix lint
  • Loading branch information
sunkickr committed Feb 2, 2024
1 parent 0d2e17f commit fe7c3fe
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 26 deletions.
58 changes: 37 additions & 21 deletions cloud/deployment/deployment.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,10 @@ func Create(name, workspaceID, description, clusterID, runtimeVersion, dagDeploy
requestedCloudProvider = astroplatformcore.CreateStandardDeploymentRequestCloudProviderAZURE
}
var requestedExecutor astroplatformcore.CreateStandardDeploymentRequestExecutor
if executor == CeleryExecutor {
if strings.EqualFold(executor, CeleryExecutor) || strings.EqualFold(executor, CELERY) {
requestedExecutor = astroplatformcore.CreateStandardDeploymentRequestExecutorCELERY
} else if executor == KubeExecutor {
}
if strings.EqualFold(executor, KubeExecutor) || strings.EqualFold(executor, KUBERNETES) {
requestedExecutor = astroplatformcore.CreateStandardDeploymentRequestExecutorKUBERNETES
}
standardDeploymentRequest := astroplatformcore.CreateStandardDeploymentRequest{
Expand All @@ -345,7 +346,7 @@ func Create(name, workspaceID, description, clusterID, runtimeVersion, dagDeploy
ResourceQuotaCpu: resourceQuotaCpu,
ResourceQuotaMemory: resourceQuotaMemory,
}
if executor == CeleryExecutor {
if strings.EqualFold(executor, CeleryExecutor) || strings.EqualFold(executor, CELERY) {
standardDeploymentRequest.WorkerQueues = &defautWorkerQueue
}
switch schedulerSize {
Expand All @@ -369,11 +370,13 @@ func Create(name, workspaceID, description, clusterID, runtimeVersion, dagDeploy
// build dedicated input
if IsDeploymentDedicated(deploymentType) {
var requestedExecutor astroplatformcore.CreateDedicatedDeploymentRequestExecutor
if executor == CeleryExecutor {
if strings.EqualFold(executor, CeleryExecutor) || strings.EqualFold(executor, CELERY) {
requestedExecutor = astroplatformcore.CreateDedicatedDeploymentRequestExecutorCELERY
fmt.Println(requestedExecutor)
}
if executor == KubeExecutor {
if strings.EqualFold(executor, KubeExecutor) || strings.EqualFold(executor, KUBERNETES) {
requestedExecutor = astroplatformcore.CreateDedicatedDeploymentRequestExecutorKUBERNETES
fmt.Println(requestedExecutor)
}
dedicatedDeploymentRequest := astroplatformcore.CreateDedicatedDeploymentRequest{
AstroRuntimeVersion: runtimeVersion,
Expand All @@ -391,10 +394,10 @@ func Create(name, workspaceID, description, clusterID, runtimeVersion, dagDeploy
ResourceQuotaCpu: resourceQuotaCpu,
ResourceQuotaMemory: resourceQuotaMemory,
}
if executor == CeleryExecutor {
if strings.EqualFold(executor, CeleryExecutor) || strings.EqualFold(executor, CELERY) {
dedicatedDeploymentRequest.WorkerQueues = &defautWorkerQueue
}

fmt.Println(dedicatedDeploymentRequest.Executor)
switch schedulerSize {
case SmallScheduler:
dedicatedDeploymentRequest.SchedulerSize = astroplatformcore.CreateDedicatedDeploymentRequestSchedulerSizeSMALL
Expand Down Expand Up @@ -440,9 +443,10 @@ func Create(name, workspaceID, description, clusterID, runtimeVersion, dagDeploy
return ErrInvalidResourceRequest
}
var requestedExecutor astroplatformcore.CreateHybridDeploymentRequestExecutor
if executor == CeleryExecutor {
if strings.EqualFold(executor, CeleryExecutor) || strings.EqualFold(executor, CELERY) {
requestedExecutor = astroplatformcore.CreateHybridDeploymentRequestExecutorCELERY
} else if executor == KubeExecutor {
}
if strings.EqualFold(executor, KubeExecutor) || strings.EqualFold(executor, KUBERNETES) {
requestedExecutor = astroplatformcore.CreateHybridDeploymentRequestExecutorKUBERNETES
}
hybridDeploymentRequest := astroplatformcore.CreateHybridDeploymentRequest{
Expand All @@ -461,7 +465,7 @@ func Create(name, workspaceID, description, clusterID, runtimeVersion, dagDeploy
},
Type: astroplatformcore.CreateHybridDeploymentRequestTypeHYBRID,
}
if executor == CeleryExecutor {
if strings.EqualFold(executor, CeleryExecutor) || strings.EqualFold(executor, CELERY) {
hybridDeploymentRequest.WorkerQueues = &defautWorkerQueue
} else {
hybridDeploymentRequest.TaskPodNodePoolId = &nodePools[0].Id
Expand Down Expand Up @@ -868,12 +872,16 @@ func Update(deploymentID, name, ws, description, deploymentName, dagDeploy, exec
}
if IsDeploymentStandard(*currentDeployment.Type) {
var requestedExecutor astroplatformcore.UpdateStandardDeploymentRequestExecutor
switch executor {
switch strings.ToUpper(executor) {
case "":
requestedExecutor = astroplatformcore.UpdateStandardDeploymentRequestExecutor(*currentDeployment.Executor)
case CeleryExecutor:
case strings.ToUpper(CeleryExecutor):
requestedExecutor = astroplatformcore.UpdateStandardDeploymentRequestExecutorCELERY
case strings.ToUpper(KubeExecutor):
requestedExecutor = astroplatformcore.UpdateStandardDeploymentRequestExecutorKUBERNETES
case strings.ToUpper(CELERY):
requestedExecutor = astroplatformcore.UpdateStandardDeploymentRequestExecutorCELERY
case KubeExecutor:
case strings.ToUpper(KUBERNETES):
requestedExecutor = astroplatformcore.UpdateStandardDeploymentRequestExecutorKUBERNETES
}

Expand Down Expand Up @@ -924,12 +932,16 @@ func Update(deploymentID, name, ws, description, deploymentName, dagDeploy, exec
}
if IsDeploymentDedicated(*currentDeployment.Type) {
var requestedExecutor astroplatformcore.UpdateDedicatedDeploymentRequestExecutor
switch executor {
switch strings.ToUpper(executor) {
case "":
requestedExecutor = astroplatformcore.UpdateDedicatedDeploymentRequestExecutor(*currentDeployment.Executor)
case CeleryExecutor:
case strings.ToUpper(CeleryExecutor):
requestedExecutor = astroplatformcore.UpdateDedicatedDeploymentRequestExecutorCELERY
case KubeExecutor:
case strings.ToUpper(KubeExecutor):
requestedExecutor = astroplatformcore.UpdateDedicatedDeploymentRequestExecutorKUBERNETES
case strings.ToUpper(CELERY):
requestedExecutor = astroplatformcore.UpdateDedicatedDeploymentRequestExecutorCELERY
case strings.ToUpper(KUBERNETES):
requestedExecutor = astroplatformcore.UpdateDedicatedDeploymentRequestExecutorKUBERNETES
}
dedicatedDeploymentRequest = astroplatformcore.UpdateDedicatedDeploymentRequest{
Expand Down Expand Up @@ -1013,12 +1025,16 @@ func Update(deploymentID, name, ws, description, deploymentName, dagDeploy, exec
return ErrInvalidResourceRequest
}
var requestedExecutor astroplatformcore.UpdateHybridDeploymentRequestExecutor
switch executor {
switch strings.ToUpper(executor) {
case "":
requestedExecutor = astroplatformcore.UpdateHybridDeploymentRequestExecutor(*currentDeployment.Executor)
case CeleryExecutor:
case strings.ToUpper(CeleryExecutor):
requestedExecutor = astroplatformcore.UpdateHybridDeploymentRequestExecutorCELERY
case strings.ToUpper(KubeExecutor):
requestedExecutor = astroplatformcore.UpdateHybridDeploymentRequestExecutorKUBERNETES
case strings.ToUpper(CELERY):
requestedExecutor = astroplatformcore.UpdateHybridDeploymentRequestExecutorCELERY
case KubeExecutor:
case strings.ToUpper(KUBERNETES):
requestedExecutor = astroplatformcore.UpdateHybridDeploymentRequestExecutorKUBERNETES
}
hybridDeploymentRequest := astroplatformcore.UpdateHybridDeploymentRequest{
Expand Down Expand Up @@ -1534,14 +1550,14 @@ func GetDeploymentURL(deploymentID, workspaceID string) (string, error) {
// It returns true if a warning was printed and false if not.
func printWarning(executor string, existingQLength int) bool {
var printed bool
if executor == KubeExecutor {
if strings.EqualFold(executor, KubeExecutor) || strings.EqualFold(executor, KUBERNETES) {
if existingQLength > 1 {
fmt.Println("\n Switching to KubernetesExecutor will replace all existing worker queues " +
"with one new default worker queue for this deployment.")
printed = true
}
} else {
if executor == CeleryExecutor {
if strings.EqualFold(executor, CeleryExecutor) || strings.EqualFold(executor, CELERY) {
fmt.Println("\n Switching to CeleryExecutor will replace the existing worker queue " +
"with a new default worker queue for this deployment.")
printed = true
Expand Down
6 changes: 3 additions & 3 deletions cloud/deployment/deployment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -962,22 +962,22 @@ func TestCreate(t *testing.T) {

// Mock user input for deployment name
defer testUtil.MockUserInput(t, "test-name")()

// Call the Create function with deployment type as STANDARD, cloud provider, and region set
err := Create("", ws, "test-desc", csID, "4.2.5", dagDeploy, CeleryExecutor, "aws", "us-west-2", SmallScheduler, "", "", "", "", "", "", astroplatformcore.DeploymentTypeSTANDARD, 0, 0, mockPlatformCoreClient, mockCoreClient, false)
assert.NoError(t, err)

err = Create("test-name", ws, "test-desc", csID, "4.2.5", dagDeploy, CeleryExecutor, azureCloud, "us-west-2", MediumScheduler, "", "", "", "", "", "", astroplatformcore.DeploymentTypeSTANDARD, 0, 0, mockPlatformCoreClient, mockCoreClient, false)
assert.NoError(t, err)

err = Create("test-name", ws, "test-desc", csID, "4.2.5", dagDeploy, KubeExecutor, gcpCloud, "us-west-2", LargeScheduler, "enable", "", "", "", "", "", astroplatformcore.DeploymentTypeSTANDARD, 0, 0, mockPlatformCoreClient, mockCoreClient, false)
err = Create("test-name", ws, "test-desc", csID, "4.2.5", dagDeploy, KUBERNETES, gcpCloud,
"us-west-2", LargeScheduler, "enable", "", "", "", "", "", astroplatformcore.DeploymentTypeSTANDARD, 0, 0, mockPlatformCoreClient, mockCoreClient, false)
assert.NoError(t, err)

// Call the Create function with deployment type as DEDICATED, cloud provider, and region set
err = Create("test-name", ws, "test-desc", csID, "4.2.5", dagDeploy, CeleryExecutor, "aws", "us-west-2", SmallScheduler, "", "", "", "", "", "", astroplatformcore.DeploymentTypeDEDICATED, 0, 0, mockPlatformCoreClient, mockCoreClient, false)
assert.NoError(t, err)

err = Create("test-name", ws, "test-desc", csID, "4.2.5", dagDeploy, CeleryExecutor, azureCloud, "us-west-2", MediumScheduler, "", "", "", "", "", "", astroplatformcore.DeploymentTypeDEDICATED, 0, 0, mockPlatformCoreClient, mockCoreClient, false)
err = Create("test-name", ws, "test-desc", csID, "4.2.5", dagDeploy, CELERY, azureCloud, "us-west-2", MediumScheduler, "", "", "", "", "", "", astroplatformcore.DeploymentTypeDEDICATED, 0, 0, mockPlatformCoreClient, mockCoreClient, false)
assert.NoError(t, err)

err = Create("test-name", ws, "test-desc", csID, "4.2.5", dagDeploy, KubeExecutor, gcpCloud, "us-west-2", LargeScheduler, "enable", "", "", "", "", "", astroplatformcore.DeploymentTypeDEDICATED, 0, 0, mockPlatformCoreClient, mockCoreClient, false)
Expand Down
2 changes: 1 addition & 1 deletion cloud/deployment/fromfile/fromfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -853,7 +853,7 @@ func createEnvVarsRequest(deploymentFromFile *inspect.FormattedDeployment) (envV

// isValidExecutor returns true for valid executor values and false if not.
func isValidExecutor(executor string) bool {
return executor == deployment.CeleryExecutor || executor == deployment.KubeExecutor || executor == deployment.CELERY || executor == deployment.KUBERNETES
return strings.EqualFold(executor, deployment.KubeExecutor) || strings.EqualFold(executor, deployment.CeleryExecutor) || strings.EqualFold(executor, deployment.CELERY) || strings.EqualFold(executor, deployment.KUBERNETES)
}

func transformDeploymentType(deploymentType string) astroplatformcore.DeploymentType {
Expand Down
24 changes: 24 additions & 0 deletions cloud/deployment/fromfile/fromfile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3685,6 +3685,30 @@ func TestIsValidExecutor(t *testing.T) {
actual := isValidExecutor(deployment.KubeExecutor)
assert.True(t, actual)
})
t.Run("returns true if executor is CELERY", func(t *testing.T) {
actual := isValidExecutor(deployment.CELERY)
assert.True(t, actual)
})
t.Run("returns true if executor is KUBERNETES", func(t *testing.T) {
actual := isValidExecutor(deployment.KUBERNETES)
assert.True(t, actual)
})
t.Run("returns true if executor is celery", func(t *testing.T) {
actual := isValidExecutor("celery")
assert.True(t, actual)
})
t.Run("returns true if executor is kubernetes", func(t *testing.T) {
actual := isValidExecutor("kubernetes")
assert.True(t, actual)
})
t.Run("returns true if executor is celery", func(t *testing.T) {
actual := isValidExecutor("celeryexecutor")
assert.True(t, actual)
})
t.Run("returns true if executor is kubernetes", func(t *testing.T) {
actual := isValidExecutor("kubernetesexecutor")
assert.True(t, actual)
})
t.Run("returns false if executor is neither Celery nor Kubernetes", func(t *testing.T) {
actual := isValidExecutor("test-executor")
assert.False(t, actual)
Expand Down
3 changes: 2 additions & 1 deletion cmd/cloud/deployment.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cloud
import (
"fmt"
"io"
"strings"

airflowversions "github.com/astronomer/astro-cli/airflow_versions"
astrocore "github.com/astronomer/astro-cli/astro-client-core"
Expand Down Expand Up @@ -549,7 +550,7 @@ func deploymentVariableUpdate(cmd *cobra.Command, args []string, out io.Writer)
}

func isValidExecutor(executor string) bool {
return executor == deployment.KubeExecutor || executor == deployment.CeleryExecutor || executor == ""
return strings.EqualFold(executor, deployment.KubeExecutor) || strings.EqualFold(executor, deployment.CeleryExecutor) || executor == "" || strings.EqualFold(executor, deployment.CELERY) || strings.EqualFold(executor, deployment.KUBERNETES)
}

// isValidCloudProvider returns true for valid CloudProvider values and false if not.
Expand Down
24 changes: 24 additions & 0 deletions cmd/cloud/deployment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -999,6 +999,30 @@ func TestIsValidExecutor(t *testing.T) {
actual := isValidExecutor(deployment.CeleryExecutor)
assert.True(t, actual)
})
t.Run("returns true if executor is CELERY", func(t *testing.T) {
actual := isValidExecutor(deployment.CELERY)
assert.True(t, actual)
})
t.Run("returns true if executor is KUBERNETES", func(t *testing.T) {
actual := isValidExecutor(deployment.KUBERNETES)
assert.True(t, actual)
})
t.Run("returns true if executor is celery", func(t *testing.T) {
actual := isValidExecutor("celery")
assert.True(t, actual)
})
t.Run("returns true if executor is kubernetes", func(t *testing.T) {
actual := isValidExecutor("kubernetes")
assert.True(t, actual)
})
t.Run("returns true if executor is celery", func(t *testing.T) {
actual := isValidExecutor("celeryexecutor")
assert.True(t, actual)
})
t.Run("returns true if executor is kubernetes", func(t *testing.T) {
actual := isValidExecutor("kubernetesexecutor")
assert.True(t, actual)
})
t.Run("returns true when no Executor is requested", func(t *testing.T) {
actual := isValidExecutor("")
assert.True(t, actual)
Expand Down

0 comments on commit fe7c3fe

Please sign in to comment.