diff --git a/pkg/providers/aws/provider.go b/pkg/providers/aws/provider.go index 256803a9..241ffcd9 100644 --- a/pkg/providers/aws/provider.go +++ b/pkg/providers/aws/provider.go @@ -39,6 +39,7 @@ const NAME = "aws" type baseProvider struct { clientFactory ClientFactory + trimTiers int } type EC2Client interface { @@ -71,9 +72,14 @@ func NamedLoader() (string, providers.Loader) { } func Loader(ctx context.Context, cfg providers.Config) (providers.Provider, *httperr.Error) { - creds, err := getCredentials(ctx, cfg.Creds) + creds, httpErr := getCredentials(ctx, cfg.Creds) + if httpErr != nil { + return nil, httpErr + } + + trimTiers, err := providers.GetTrimTiers(cfg.Params) if err != nil { - return nil, err + return nil, httperr.NewError(http.StatusBadRequest, "parameters error: "+err.Error()) } clientFactory := func(region string, pageSize *int) (*Client, error) { @@ -96,7 +102,7 @@ func Loader(ctx context.Context, cfg providers.Config) (providers.Provider, *htt }, nil } - return New(clientFactory), nil + return New(clientFactory, trimTiers), nil } func getCredentials(ctx context.Context, creds map[string]any) (*Credentials, *httperr.Error) { @@ -167,16 +173,19 @@ func (p *baseProvider) GenerateTopologyConfig(ctx context.Context, pageSize *int klog.Infof("Extracted topology for %d instances", topo.Len()) - return topo.ToThreeTierGraph(NAME, instances, false), nil + return topo.ToThreeTierGraph(NAME, instances, p.trimTiers, false), nil } type Provider struct { baseProvider } -func New(clientFactory ClientFactory) *Provider { +func New(clientFactory ClientFactory, trimTiers int) *Provider { return &Provider{ - baseProvider: baseProvider{clientFactory: clientFactory}, + baseProvider: baseProvider{ + clientFactory: clientFactory, + trimTiers: trimTiers, + }, } } diff --git a/pkg/providers/aws/provider_sim.go b/pkg/providers/aws/provider_sim.go index 20a116c6..2f936615 100644 --- a/pkg/providers/aws/provider_sim.go +++ b/pkg/providers/aws/provider_sim.go @@ -183,16 +183,19 @@ func LoaderSim(ctx context.Context, cfg providers.Config) (providers.Provider, * }, nil } - return NewSim(clientFactory), nil + return NewSim(clientFactory, p.TrimTiers), nil } type simProvider struct { baseProvider } -func NewSim(clientFactory ClientFactory) *simProvider { +func NewSim(clientFactory ClientFactory, trimTiers int) *simProvider { return &simProvider{ - baseProvider: baseProvider{clientFactory: clientFactory}, + baseProvider: baseProvider{ + clientFactory: clientFactory, + trimTiers: trimTiers, + }, } } diff --git a/pkg/providers/aws/provider_test.go b/pkg/providers/aws/provider_test.go index 5a3c30d8..13354321 100644 --- a/pkg/providers/aws/provider_test.go +++ b/pkg/providers/aws/provider_test.go @@ -46,12 +46,12 @@ func TestGetCredentials(t *testing.T) { { name: "Case 3: invalid secretAccessKey", creds: map[string]any{"accessKeyId": "id", "secretAccessKey": false}, - err: "credentials error: 'secretAccessKey' must be a string", + err: "credentials error: 'secretAccessKey' must be of type string", }, { name: "Case 4: invalid token", creds: map[string]any{"accessKeyId": "id", "secretAccessKey": "secret", "token": false}, - err: "credentials error: 'token' must be a string", + err: "credentials error: 'token' must be of type string", }, { name: "Case 5: valid provided creds", diff --git a/pkg/providers/gcp/provider.go b/pkg/providers/gcp/provider.go index 2cd72ab6..8d83e85e 100644 --- a/pkg/providers/gcp/provider.go +++ b/pkg/providers/gcp/provider.go @@ -39,6 +39,7 @@ const NAME = "gcp" type baseProvider struct { clientFactory ClientFactory + trimTiers int } type ClientFactory func(pageSize *int) (Client, error) @@ -83,6 +84,10 @@ func Loader(ctx context.Context, config providers.Config) (providers.Provider, * if err != nil { return nil, httperr.NewError(http.StatusBadRequest, err.Error()) } + trimTiers, err := providers.GetTrimTiers(config.Params) + if err != nil { + return nil, httperr.NewError(http.StatusBadRequest, "parameters error: "+err.Error()) + } clientFactory := func(pageSize *int) (Client, error) { instanceClient, err := compute.NewInstancesRESTClient(ctx) if err != nil { @@ -96,7 +101,7 @@ func Loader(ctx context.Context, config providers.Config) (providers.Provider, * }, nil } - return New(clientFactory), nil + return New(clientFactory, trimTiers), nil } func getProjectID(ctx context.Context, params map[string]any) (string, error) { @@ -142,16 +147,19 @@ func (p *baseProvider) GenerateTopologyConfig(ctx context.Context, pageSize *int return nil, err } - return topo.ToThreeTierGraph(NAME, instances, false), nil + return topo.ToThreeTierGraph(NAME, instances, p.trimTiers, false), nil } type Provider struct { baseProvider } -func New(clientFactory ClientFactory) *Provider { +func New(clientFactory ClientFactory, trimTiers int) *Provider { return &Provider{ - baseProvider: baseProvider{clientFactory: clientFactory}, + baseProvider: baseProvider{ + clientFactory: clientFactory, + trimTiers: trimTiers, + }, } } diff --git a/pkg/providers/gcp/provider_sim.go b/pkg/providers/gcp/provider_sim.go index 3fbc37b2..5e3f032a 100644 --- a/pkg/providers/gcp/provider_sim.go +++ b/pkg/providers/gcp/provider_sim.go @@ -160,16 +160,19 @@ func LoaderSim(_ context.Context, cfg providers.Config) (providers.Provider, *ht }, nil } - return NewSim(clientFactory), nil + return NewSim(clientFactory, p.TrimTiers), nil } type simProvider struct { baseProvider } -func NewSim(clientFactory ClientFactory) *simProvider { +func NewSim(clientFactory ClientFactory, trimTiers int) *simProvider { return &simProvider{ - baseProvider: baseProvider{clientFactory: clientFactory}, + baseProvider: baseProvider{ + clientFactory: clientFactory, + trimTiers: trimTiers, + }, } } diff --git a/pkg/providers/gcp/provider_test.go b/pkg/providers/gcp/provider_test.go index 0d501113..b1de60cb 100644 --- a/pkg/providers/gcp/provider_test.go +++ b/pkg/providers/gcp/provider_test.go @@ -39,12 +39,12 @@ func TestGetProjectID(t *testing.T) { name: "Case 3: invalid project_id in params", params: map[string]any{"project_id": false}, content: []byte(`{"project_id": "test-project"}`), - err: "error in topology request parameters: 'project_id' must be a string", + err: "error in topology request parameters: 'project_id' must be of type string", }, { name: "Case 4: invalid project_id in cert keys", content: []byte(`{"project_id": false}`), - err: "error in GOOGLE_APPLICATION_CREDENTIALS: 'project_id' must be a string", + err: "error in GOOGLE_APPLICATION_CREDENTIALS: 'project_id' must be of type string", }, { name: "Case 5: invalid credentials file path", diff --git a/pkg/providers/lambdai/provider.go b/pkg/providers/lambdai/provider.go index c52980c4..0e7183aa 100644 --- a/pkg/providers/lambdai/provider.go +++ b/pkg/providers/lambdai/provider.go @@ -37,6 +37,7 @@ type ClientFactory func(pageSize *int) (Client, error) type baseProvider struct { clientFactory ClientFactory + trimTiers int } // lambdaiClient is a Topology API client. @@ -123,6 +124,10 @@ func Loader(ctx context.Context, config providers.Config) (providers.Provider, * if err != nil { return nil, httperr.NewError(http.StatusBadRequest, "parameters error: "+err.Error()) } + trimTiers, err := providers.GetTrimTiers(config.Params) + if err != nil { + return nil, httperr.NewError(http.StatusBadRequest, "parameters error: "+err.Error()) + } clientFactory := func(pageSize *int) (Client, error) { return &lambdaiClient{ @@ -133,7 +138,7 @@ func Loader(ctx context.Context, config providers.Config) (providers.Provider, * }, nil } - return New(clientFactory), nil + return New(clientFactory, trimTiers), nil } func getPageSize(sz *int) int { @@ -149,15 +154,18 @@ func (p *baseProvider) GenerateTopologyConfig(ctx context.Context, pageSize *int return nil, err } - return topo.ToThreeTierGraph(NAME, instances, false), nil + return topo.ToThreeTierGraph(NAME, instances, p.trimTiers, false), nil } type Provider struct { baseProvider } -func New(clientFactory ClientFactory) *Provider { +func New(clientFactory ClientFactory, trimTiers int) *Provider { return &Provider{ - baseProvider: baseProvider{clientFactory: clientFactory}, + baseProvider: baseProvider{ + clientFactory: clientFactory, + trimTiers: trimTiers, + }, } } diff --git a/pkg/providers/lambdai/provider_sim.go b/pkg/providers/lambdai/provider_sim.go index 2110dce7..619acbe6 100644 --- a/pkg/providers/lambdai/provider_sim.go +++ b/pkg/providers/lambdai/provider_sim.go @@ -123,16 +123,19 @@ func LoaderSim(_ context.Context, cfg providers.Config) (providers.Provider, *ht }, nil } - return NewSim(clientFactory), nil + return NewSim(clientFactory, p.TrimTiers), nil } type simProvider struct { baseProvider } -func NewSim(clientFactory ClientFactory) *simProvider { +func NewSim(clientFactory ClientFactory, trimTiers int) *simProvider { return &simProvider{ - baseProvider: baseProvider{clientFactory: clientFactory}, + baseProvider: baseProvider{ + clientFactory: clientFactory, + trimTiers: trimTiers, + }, } } diff --git a/pkg/providers/lambdai/provider_test.go b/pkg/providers/lambdai/provider_test.go index 3058e9c1..fd5664a7 100644 --- a/pkg/providers/lambdai/provider_test.go +++ b/pkg/providers/lambdai/provider_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/NVIDIA/topograph/pkg/providers" + "github.com/NVIDIA/topograph/pkg/topology" ) func TestLoader(t *testing.T) { @@ -67,6 +68,20 @@ func TestLoader(t *testing.T) { }, err: "parameters error: missing 'url'", }, + { + name: "Case 5: invalid trimTiers", + config: providers.Config{ + Creds: map[string]any{ + authWorkspaceID: "workspace-123", + authToken: "token-abc", + }, + Params: map[string]any{ + apiBaseURL: "https://api.example.com", + topology.KeyTrimTiers: false, + }, + }, + err: "parameters error: invalid 'trimTiers' value 'false': unsupported type bool", + }, } for _, tt := range tests { diff --git a/pkg/providers/nebius/provider.go b/pkg/providers/nebius/provider.go index 843d55e7..a6193613 100644 --- a/pkg/providers/nebius/provider.go +++ b/pkg/providers/nebius/provider.go @@ -48,6 +48,7 @@ type ClientFactory func(pageSize *int) (Client, error) type baseProvider struct { clientFactory ClientFactory + trimTiers int } type nebiusClient struct { @@ -78,6 +79,11 @@ func Loader(ctx context.Context, config providers.Config) (providers.Provider, * return nil, httpErr } + trimTiers, err := providers.GetTrimTiers(config.Params) + if err != nil { + return nil, httperr.NewError(http.StatusBadRequest, "parameters error: "+err.Error()) + } + // if project ID is not passed in credentials, get it from file projectID, err := providers.StringFromMap(authProjectID, config.Creds, false) if err != nil { @@ -101,7 +107,7 @@ func Loader(ctx context.Context, config providers.Config) (providers.Provider, * }, nil } - return New(clientFactory), nil + return New(clientFactory, trimTiers), nil } func getAuthOption(creds map[string]any) (gosdk.Option, *httperr.Error) { @@ -177,16 +183,19 @@ func (p *baseProvider) GenerateTopologyConfig(ctx context.Context, pageSize *int return nil, err } - return topo.ToThreeTierGraph(NAME, instances, false), nil + return topo.ToThreeTierGraph(NAME, instances, p.trimTiers, false), nil } type Provider struct { baseProvider } -func New(clientFactory ClientFactory) *Provider { +func New(clientFactory ClientFactory, trimTiers int) *Provider { return &Provider{ - baseProvider: baseProvider{clientFactory: clientFactory}, + baseProvider: baseProvider{ + clientFactory: clientFactory, + trimTiers: trimTiers, + }, } } diff --git a/pkg/providers/nebius/provider_sim.go b/pkg/providers/nebius/provider_sim.go index 77881d2c..f79f6735 100644 --- a/pkg/providers/nebius/provider_sim.go +++ b/pkg/providers/nebius/provider_sim.go @@ -131,16 +131,19 @@ func LoaderSim(ctx context.Context, cfg providers.Config) (providers.Provider, * }, nil } - return NewSim(clientFactory), nil + return NewSim(clientFactory, p.TrimTiers), nil } type simProvider struct { baseProvider } -func NewSim(factory ClientFactory) *simProvider { +func NewSim(factory ClientFactory, trimTiers int) *simProvider { return &simProvider{ - baseProvider: baseProvider{clientFactory: factory}, + baseProvider: baseProvider{ + clientFactory: factory, + trimTiers: trimTiers, + }, } } diff --git a/pkg/providers/nebius/provider_test.go b/pkg/providers/nebius/provider_test.go index 7e9c9329..e9fc68b3 100644 --- a/pkg/providers/nebius/provider_test.go +++ b/pkg/providers/nebius/provider_test.go @@ -13,7 +13,7 @@ import ( ) func TestGetAuthOption(t *testing.T) { - testCases := []struct { + tests := []struct { name string creds map[string]any env bool @@ -55,15 +55,15 @@ func TestGetAuthOption(t *testing.T) { }, } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if tc.env { + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.env { os.Setenv(authTokenEnvVar, "data") defer os.Unsetenv(authTokenEnvVar) } - _, err := getAuthOption(tc.creds) - if len(tc.err) != 0 { - require.EqualError(t, err, tc.err) + _, err := getAuthOption(tt.creds) + if len(tt.err) != 0 { + require.EqualError(t, err, tt.err) } else { require.Nil(t, err) } @@ -72,31 +72,31 @@ func TestGetAuthOption(t *testing.T) { } func TestGetUserAgentPrefix(t *testing.T) { - testCases := []struct { - name string - version string - want string + tests := []struct { + name string + version string + expected string }{ { - name: "empty version", - version: "", - want: userAgentProduct, + name: "Case 1: empty version", + version: "", + expected: userAgentProduct, }, { - name: "whitespace version", - version: " ", - want: userAgentProduct, + name: "Case 2: whitespace version", + version: " ", + expected: userAgentProduct, }, { - name: "non-empty version", - version: "main", - want: "nvidia-topograph/main", + name: "Case 3: non-empty version", + version: "main", + expected: "nvidia-topograph/main", }, } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - require.Equal(t, tc.want, getUserAgentPrefix(tc.version)) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, getUserAgentPrefix(tt.version)) }) } } diff --git a/pkg/providers/oci/provider.go b/pkg/providers/oci/provider.go index 7d9b8060..474130b4 100644 --- a/pkg/providers/oci/provider.go +++ b/pkg/providers/oci/provider.go @@ -21,6 +21,7 @@ import ( ) type baseProvider struct { + trimTiers int } // Engine support diff --git a/pkg/providers/oci/provider_api.go b/pkg/providers/oci/provider_api.go index d76d2f01..d9f2962b 100644 --- a/pkg/providers/oci/provider_api.go +++ b/pkg/providers/oci/provider_api.go @@ -88,6 +88,11 @@ func LoaderAPI(ctx context.Context, config providers.Config) (providers.Provider return nil, httperr.NewError(http.StatusBadGateway, fmt.Sprintf("unable to get tenancy OCID from config: %v", err)) } + trimTiers, err := providers.GetTrimTiers(config.Params) + if err != nil { + return nil, httperr.NewError(http.StatusBadRequest, "parameters error: "+err.Error()) + } + clientFactory := func(region string, pageSize *int) (Client, error) { identityClient, err := identity.NewIdentityClientWithConfigurationProvider(provider) if err != nil { @@ -113,7 +118,7 @@ func LoaderAPI(ctx context.Context, config providers.Config) (providers.Provider }, nil } - return NewAPI(clientFactory), nil + return NewAPI(clientFactory, trimTiers), nil } func getConfigurationProvider(creds map[string]any) (common.ConfigurationProvider, *httperr.Error) { @@ -163,8 +168,11 @@ func getConfigurationProvider(creds map[string]any) (common.ConfigurationProvide return configProvider, nil } -func NewAPI(clientFactory ClientFactory) *apiProvider { - return &apiProvider{clientFactory: clientFactory} +func NewAPI(clientFactory ClientFactory, trimTiers int) *apiProvider { + return &apiProvider{ + baseProvider: baseProvider{trimTiers: trimTiers}, + clientFactory: clientFactory, + } } func (p *apiProvider) GenerateTopologyConfig(ctx context.Context, pageSize *int, instances []topology.ComputeInstances) (*topology.Vertex, *httperr.Error) { @@ -173,7 +181,7 @@ func (p *apiProvider) GenerateTopologyConfig(ctx context.Context, pageSize *int, return nil, err } - return topo.ToThreeTierGraph(NAME, instances, true), nil + return topo.ToThreeTierGraph(NAME, instances, p.trimTiers, true), nil } func (p *apiProvider) generateInstanceTopology(ctx context.Context, pageSize *int, cis []topology.ComputeInstances) (*topology.ClusterTopology, *httperr.Error) { diff --git a/pkg/providers/oci/provider_api_test.go b/pkg/providers/oci/provider_api_test.go index e4fdd0ab..32d5d64e 100644 --- a/pkg/providers/oci/provider_api_test.go +++ b/pkg/providers/oci/provider_api_test.go @@ -47,7 +47,7 @@ func TestGetConfigurationProvider(t *testing.T) { { name: "Case 1: invalid tenancyId", creds: map[string]any{"tenancyId": false}, - err: "credentials error: 'tenancyId' must be a string", + err: "credentials error: 'tenancyId' must be of type string", }, { name: "Case 2: missing userId", diff --git a/pkg/providers/oci/provider_imds.go b/pkg/providers/oci/provider_imds.go index 30729092..bc3578e4 100644 --- a/pkg/providers/oci/provider_imds.go +++ b/pkg/providers/oci/provider_imds.go @@ -38,8 +38,12 @@ func NamedLoaderIMDS() (string, providers.Loader) { return NAME_IMDS, LoaderIMDS } -func LoaderIMDS(_ context.Context, _ providers.Config) (providers.Provider, *httperr.Error) { - return &imdsProvider{}, nil +func LoaderIMDS(_ context.Context, config providers.Config) (providers.Provider, *httperr.Error) { + trimTiers, err := providers.GetTrimTiers(config.Params) + if err != nil { + return nil, httperr.NewError(http.StatusBadRequest, "parameters error: "+err.Error()) + } + return &imdsProvider{baseProvider: baseProvider{trimTiers: trimTiers}}, nil } func (p *imdsProvider) GenerateTopologyConfig(ctx context.Context, _ *int, instances []topology.ComputeInstances) (*topology.Vertex, *httperr.Error) { @@ -48,7 +52,7 @@ func (p *imdsProvider) GenerateTopologyConfig(ctx context.Context, _ *int, insta return nil, httperr.NewError(http.StatusInternalServerError, err.Error()) } - return topo.ToThreeTierGraph(NAME, instances, true), nil + return topo.ToThreeTierGraph(NAME, instances, p.trimTiers, true), nil } func (p *imdsProvider) generateInstanceTopology(ctx context.Context, cis []topology.ComputeInstances) (*topology.ClusterTopology, error) { diff --git a/pkg/providers/oci/provider_sim.go b/pkg/providers/oci/provider_sim.go index 0b5b542b..6af17fac 100644 --- a/pkg/providers/oci/provider_sim.go +++ b/pkg/providers/oci/provider_sim.go @@ -157,16 +157,19 @@ func LoaderSim(ctx context.Context, cfg providers.Config) (providers.Provider, * }, nil } - return NewSim(clientFactory), nil + return NewSim(clientFactory, p.TrimTiers), nil } type simProvider struct { apiProvider } -func NewSim(factory ClientFactory) *simProvider { +func NewSim(factory ClientFactory, trimTiers int) *simProvider { return &simProvider{ - apiProvider: apiProvider{clientFactory: factory}, + apiProvider: apiProvider{ + baseProvider: baseProvider{trimTiers: trimTiers}, + clientFactory: factory, + }, } } diff --git a/pkg/providers/providers.go b/pkg/providers/providers.go index 7c00d6d9..d348e02a 100644 --- a/pkg/providers/providers.go +++ b/pkg/providers/providers.go @@ -116,16 +116,48 @@ func ReadFile(path string) (string, error) { return string(data), nil } -func StringFromMap(key string, m map[string]any, must bool) (string, error) { +func FromMap[T any](key string, m map[string]any, must bool) (T, error) { + var zero T + v, ok := m[key] if !ok || v == nil { if must { - return "", fmt.Errorf("missing '%s'", key) + return zero, fmt.Errorf("missing '%s'", key) } - return "", nil + return zero, nil + } + + if val, ok := v.(T); ok { + return val, nil + } + + return zero, fmt.Errorf("'%s' must be of type %T", key, zero) +} + +func StringFromMap(key string, m map[string]any, must bool) (string, error) { + return FromMap[string](key, m, must) +} + +func GetTrimTiers(params map[string]any) (int, error) { + v, ok := params[topology.KeyTrimTiers] + if !ok || v == nil { + return 0, nil } - if str, ok := v.(string); ok { - return str, nil + + var trimTiers int + switch val := v.(type) { + case int: + trimTiers = val + case float64: + trimTiers = int(val) + default: + return 0, fmt.Errorf("invalid '%s' value '%v': unsupported type %T", topology.KeyTrimTiers, v, v) + } + + // support up to 2 trimmed tiers: core and spine + if trimTiers < 0 || trimTiers > 2 { + return 0, fmt.Errorf("invalid '%s' value '%v': must be an integer between 0 and 2", topology.KeyTrimTiers, v) } - return "", fmt.Errorf("'%s' must be a string", key) + + return trimTiers, nil } diff --git a/pkg/providers/providers_sim.go b/pkg/providers/providers_sim.go index ff8998bb..fe15418e 100644 --- a/pkg/providers/providers_sim.go +++ b/pkg/providers/providers_sim.go @@ -28,6 +28,7 @@ var ErrAPIError = errors.New("API error") type SimulationParams struct { ModelFileName string `mapstructure:"modelFileName"` APIError int `mapstructure:"api_error"` + TrimTiers int `mapstructure:"trimTiers"` } func GetSimulationParams(params map[string]any) (*SimulationParams, error) { diff --git a/pkg/providers/providers_test.go b/pkg/providers/providers_test.go index ad2c493d..9590e331 100644 --- a/pkg/providers/providers_test.go +++ b/pkg/providers/providers_test.go @@ -21,6 +21,7 @@ import ( "os" "testing" + "github.com/NVIDIA/topograph/pkg/topology" "github.com/stretchr/testify/require" ) @@ -44,7 +45,7 @@ node4: instance4 } func TestReadFile(t *testing.T) { - testCases := []struct { + tests := []struct { name string exists bool data string @@ -65,19 +66,19 @@ func TestReadFile(t *testing.T) { line2`, }, } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { var path string - if tc.exists { + if tt.exists { f, err := os.CreateTemp("", "test-*") require.NoError(t, err) path = f.Name() defer func() { _ = os.Remove(path) }() defer func() { _ = f.Close() }() - if len(tc.data) != 0 { - n, err := f.WriteString(tc.data) + if len(tt.data) != 0 { + n, err := f.WriteString(tt.data) require.NoError(t, err) - require.Equal(t, len(tc.data), n) + require.Equal(t, len(tt.data), n) err = f.Sync() require.NoError(t, err) } @@ -86,11 +87,165 @@ line2`, } data, err := ReadFile(path) - if tc.err { + if tt.err { require.Error(t, err) } else { require.NoError(t, err) - require.Equal(t, tc.data, data) + require.Equal(t, tt.data, data) + } + }) + } +} + +func TestFromMap(t *testing.T) { + m := map[string]any{ + "name": "switch01", + "count": 3, + "nilv": nil, + } + + tests := []struct { + name string + key string + must bool + expect any + err string + testFn func(string, map[string]any, bool) (any, error) + }{ + { + name: "Case 1: string success", + key: "name", + must: true, + expect: "switch01", + testFn: func(k string, m map[string]any, must bool) (any, error) { + return FromMap[string](k, m, must) + }, + }, + { + name: "Case 2: int success", + key: "count", + must: true, + expect: 3, + testFn: func(k string, m map[string]any, must bool) (any, error) { + return FromMap[int](k, m, must) + }, + }, + { + name: "Case 3: missing optional", + key: "missing", + must: false, + expect: "", + testFn: func(k string, m map[string]any, must bool) (any, error) { + return FromMap[string](k, m, must) + }, + }, + { + name: "Case 4: missing required", + key: "missing", + must: true, + err: "missing 'missing'", + testFn: func(k string, m map[string]any, must bool) (any, error) { + return FromMap[string](k, m, must) + }, + }, + { + name: "Case 5: wrong type", + key: "name", + must: true, + err: "'name' must be of type int", + testFn: func(k string, m map[string]any, must bool) (any, error) { + return FromMap[int](k, m, must) + }, + }, + { + name: "Case 6: nil required", + key: "nilv", + must: true, + err: "missing 'nilv'", + testFn: func(k string, m map[string]any, must bool) (any, error) { + return FromMap[string](k, m, must) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v, err := tt.testFn(tt.key, m, tt.must) + + if len(tt.err) != 0 { + require.EqualError(t, err, tt.err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expect, v) + } + }) + } +} + +func TestGetTrimTiers(t *testing.T) { + tests := []struct { + name string + params map[string]any + expected int + err string + }{ + { + name: "Case 1: missing key", + params: map[string]any{}, + }, + { + name: "Case 2: nil value", + params: map[string]any{ + topology.KeyTrimTiers: nil, + }, + }, + { + name: "Case 3: int value", + params: map[string]any{ + topology.KeyTrimTiers: 1, + }, + expected: 1, + }, + { + name: "Case 4: float64 value", + params: map[string]any{ + topology.KeyTrimTiers: float64(2), + }, + expected: 2, + }, + { + name: "Case 5: negative value", + params: map[string]any{ + topology.KeyTrimTiers: -1, + }, + err: "invalid 'trimTiers' value '-1': must be an integer between 0 and 2", + }, + { + name: "Case 6: value greater than 2", + params: map[string]any{ + topology.KeyTrimTiers: 3, + }, + err: "invalid 'trimTiers' value '3': must be an integer between 0 and 2", + }, + { + name: "Case 7: unsupported type", + params: map[string]any{ + topology.KeyTrimTiers: "1", + }, + expected: 0, + err: "invalid 'trimTiers' value '1': unsupported type string", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := GetTrimTiers(tt.params) + + if len(tt.err) != 0 { + require.EqualError(t, err, tt.err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expected, result) } }) } diff --git a/pkg/server/grpc_client.go b/pkg/server/grpc_client.go index 6e9883e5..da7e2653 100644 --- a/pkg/server/grpc_client.go +++ b/pkg/server/grpc_client.go @@ -27,11 +27,18 @@ import ( "github.com/NVIDIA/topograph/internal/httperr" pb "github.com/NVIDIA/topograph/pkg/protos" + "github.com/NVIDIA/topograph/pkg/providers" "github.com/NVIDIA/topograph/pkg/topology" ) func forwardRequest(ctx context.Context, tr *topology.Request, url string, cis []topology.ComputeInstances) (*topology.Vertex, *httperr.Error) { klog.Infof("Forwarding request to %s", url) + + trimTiers, err := providers.GetTrimTiers(tr.Provider.Params) + if err != nil { + return nil, httperr.NewError(http.StatusBadRequest, "parameters error: "+err.Error()) + } + conn, err := grpc.NewClient(url, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, httperr.NewError(http.StatusInternalServerError, fmt.Sprintf("failed to connect to %s: %v", url, err)) @@ -63,7 +70,7 @@ func forwardRequest(ctx context.Context, tr *topology.Request, url string, cis [ } } - return topo.ToThreeTierGraph(tr.Provider.Name, cis, false), nil + return topo.ToThreeTierGraph(tr.Provider.Name, cis, trimTiers, false), nil } func convert(inst *pb.Instance) *topology.InstanceTopology { diff --git a/pkg/server/http_server_test.go b/pkg/server/http_server_test.go index 4fd8b5a0..59fade72 100644 --- a/pkg/server/http_server_test.go +++ b/pkg/server/http_server_test.go @@ -91,6 +91,45 @@ SwitchName=sw12 Nodes=n-[1201-1202] SwitchName=sw13 Nodes=n-[1301-1302] SwitchName=sw14 Nodes=n-[1401-1402] SwitchName=no-topology Nodes=n-CPU +` + + slurmTrimmedTreePayload = ` +{ + "provider": { + "name": "%s", + "params": { + "modelFileName": "../../tests/models/medium.yaml", + "trimTiers": 1 + } + }, + "engine": { + "name": "slurm" + }, + "nodes": [ + { + "region": "R1", + "instances": { + "1101": "n-1101", + "1102": "n-1102", + "1201": "n-1201", + "1202": "n-1202", + "1301": "n-1301", + "1302": "n-1302", + "1401": "n-1401", + "1402": "n-1402", + "1500": "n-CPU" + } + } + ] +} +` + slurmTrimmedTreeConfig = `SwitchName=sw21 Switches=sw[11-12] +SwitchName=sw22 Switches=sw[13-14] +SwitchName=sw11 Nodes=n-[1101-1102] +SwitchName=sw12 Nodes=n-[1201-1202] +SwitchName=sw13 Nodes=n-[1301-1302] +SwitchName=sw14 Nodes=n-[1401-1402] +SwitchName=no-topology Nodes=n-CPU ` slurmBlockPayload = ` @@ -240,14 +279,21 @@ func TestServerLocal(t *testing.T) { expected: slurmTreeConfig, }, { - name: "Case 7: mock Lambda request for tree topology", + name: "Case 7: mock GCP request for trimmed tree topology", + endpoint: "generate", + provider: "gcp-sim", + payload: slurmTrimmedTreePayload, + expected: slurmTrimmedTreeConfig, + }, + { + name: "Case 8: mock Lambda request for tree topology", endpoint: "generate", provider: "lambdai-sim", payload: slurmTreePayload, expected: slurmTreeConfig, }, { - name: "Case 8: mock request for topology with invalid UID", + name: "Case 9: mock request for topology with invalid UID", endpoint: "topology", payload: "invalid-uid", expected: "request ID invalid-uid not found\n", diff --git a/pkg/topology/graph.go b/pkg/topology/graph.go index 052dd23c..be9c1e61 100644 --- a/pkg/topology/graph.go +++ b/pkg/topology/graph.go @@ -90,7 +90,7 @@ func (c *ClusterTopology) Len() int { return len(c.Instances) } -func (c *ClusterTopology) ToThreeTierGraph(provider string, cis []ComputeInstances, normalize bool) *Vertex { +func (c *ClusterTopology) ToThreeTierGraph(provider string, cis []ComputeInstances, trimTiers int, normalize bool) *Vertex { i2n := make(map[string]string) for _, ci := range cis { maps.Copy(i2n, ci.Instances) @@ -123,7 +123,8 @@ func (c *ClusterTopology) ToThreeTierGraph(provider string, cis []ComputeInstanc } swNames := [3]string{inst.LeafName, inst.SpineName, inst.CoreName} - for i, swID := range []string{inst.LeafID, inst.SpineID, inst.CoreID} { + + for i, swID := range trimmedTiers(inst, trimTiers) { if len(swID) == 0 { continue } @@ -225,3 +226,12 @@ func (c *ClusterTopology) Normalize() { c.Instances[i].CoreName = name } } + +func trimmedTiers(inst *InstanceTopology, trimTiers int) []string { + tiers := []string{inst.LeafID, inst.SpineID, inst.CoreID} + n := len(tiers) + for i := 0; i < trimTiers && i < n; i++ { + tiers[n-i-1] = "" + } + return tiers +} diff --git a/pkg/topology/graph_test.go b/pkg/topology/graph_test.go index e9193c4a..0d380d3d 100644 --- a/pkg/topology/graph_test.go +++ b/pkg/topology/graph_test.go @@ -138,7 +138,7 @@ func TestToThreeTierGraphNoNorm(t *testing.T) { Vertices: map[string]*Vertex{TopologyTree: v0, TopologyBlock: blocks}, } - graph := topo.ToThreeTierGraph("test", []ComputeInstances{{Instances: i2n}}, false) + graph := topo.ToThreeTierGraph("test", []ComputeInstances{{Instances: i2n}}, 0, false) require.Equal(t, expected, graph) } @@ -208,7 +208,7 @@ func TestToThreeTierGraphNorm(t *testing.T) { Vertices: map[string]*Vertex{TopologyTree: v0, TopologyBlock: blocks}, } - graph := topo.ToThreeTierGraph("test", []ComputeInstances{{Instances: i2n}}, true) + graph := topo.ToThreeTierGraph("test", []ComputeInstances{{Instances: i2n}}, 0, true) require.Equal(t, expected, graph) inst0 := "Instance:i-001 Leaf:nn-11111111 (switch.1.1) Spine:nn-55555555 (switch.2.1) Core:nn-77777777 (switch.3.1) Accelerator:acc-111111" @@ -217,3 +217,70 @@ func TestToThreeTierGraphNorm(t *testing.T) { inst2 := "Instance:i-003 Leaf:nn-33333333 (switch.1.3) Spine:nn-66666666 (switch.2.2) Core:nn-77777777 (switch.3.1)" require.Equal(t, inst2, topo.Instances[2].String()) } + +func TestTrimTiers(t *testing.T) { + tests := []struct { + name string + trimTiers int + in InstanceTopology + out []string + }{ + { + name: "Case 1: trim none", + trimTiers: 0, + in: InstanceTopology{ + CoreID: "core1", + SpineID: "spine1", + LeafID: "leaf1", + }, + out: []string{"leaf1", "spine1", "core1"}, + }, + { + name: "Case 2: trim 1 tier", + trimTiers: 1, + in: InstanceTopology{ + CoreID: "core1", + SpineID: "spine1", + LeafID: "leaf1", + }, + out: []string{"leaf1", "spine1", ""}, + }, + { + name: "Case 3: trim 2 tiers", + trimTiers: 2, + in: InstanceTopology{ + CoreID: "core1", + SpineID: "spine1", + LeafID: "leaf1", + }, + out: []string{"leaf1", "", ""}, + }, + { + name: "Case 4: trim all tiers", + trimTiers: 3, + in: InstanceTopology{ + CoreID: "core1", + SpineID: "spine1", + LeafID: "leaf1", + }, + out: []string{"", "", ""}, + }, + { + name: "Case 5: trim more than available", + trimTiers: 10, + in: InstanceTopology{ + CoreID: "core1", + SpineID: "spine1", + LeafID: "leaf1", + }, + out: []string{"", "", ""}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inst := tt.in + require.Equal(t, tt.out, trimmedTiers(&inst, tt.trimTiers)) + }) + } +} diff --git a/pkg/topology/request_test.go b/pkg/topology/request_test.go index 6d4dcf6b..0add4f23 100644 --- a/pkg/topology/request_test.go +++ b/pkg/topology/request_test.go @@ -67,7 +67,9 @@ func TestPayload(t *testing.T) { "accessKeyId": "id", "secretAccessKey": "secret" }, - "params": {} + "params": { + "trimTiers": 2 + } }, "engine": { "name": "slurm", @@ -104,7 +106,7 @@ func TestPayload(t *testing.T) { "accessKeyId": "id", "secretAccessKey": "secret", }, - Params: map[string]any{}, + Params: map[string]any{"trimTiers": float64(2)}, }, Engine: Engine{ Name: "slurm", @@ -136,7 +138,7 @@ func TestPayload(t *testing.T) { print: `TopologyRequest: Provider: aws Credentials: [accessKeyId:*** secretAccessKey:***] - Parameters: [] + Parameters: [trimTiers:2] Engine: slurm Parameters: [block_sizes:30,120 plugin:topology/block reconfigure:true] Nodes: region1: [instance1:node1 instance2:node2 instance3:node3] region2: [instance4:node4 instance5:node5 instance6:node6] diff --git a/pkg/topology/topology.go b/pkg/topology/topology.go index 1c8075d5..aa91ffeb 100644 --- a/pkg/topology/topology.go +++ b/pkg/topology/topology.go @@ -31,6 +31,7 @@ const ( KeyTopoConfigPath = "topologyConfigPath" KeyTopoConfigmapName = "topologyConfigmapName" KeyBlockSizes = "block_sizes" + KeyTrimTiers = "trimTiers" KeyPlugin = "plugin" TopologyTree = "topology/tree"