diff --git a/go/internal/feast/featurestore.go b/go/internal/feast/featurestore.go index cd9915cabf5..73d6ca618df 100644 --- a/go/internal/feast/featurestore.go +++ b/go/internal/feast/featurestore.go @@ -442,6 +442,11 @@ func (fs *FeatureStore) GetOnlineFeaturesRange( result = append(result, vectors...) } + result, err = onlineserving.KeepOnlyRequestedFeatures(result, featureRefs, featureService, fullFeatureNames) + if err != nil { + return nil, err + } + return result, nil } diff --git a/go/internal/feast/integration_tests/scylladb/scylladb_integration_test.go b/go/internal/feast/integration_tests/scylladb/scylladb_integration_test.go index 6f3810e50e1..43bd89a2314 100644 --- a/go/internal/feast/integration_tests/scylladb/scylladb_integration_test.go +++ b/go/internal/feast/integration_tests/scylladb/scylladb_integration_test.go @@ -98,6 +98,49 @@ func TestGetOnlineFeaturesRange(t *testing.T) { assertResponseData(t, response, featureNames) } +func TestGetOnlineFeaturesRange_includesDuplicatedRequestedFeatures(t *testing.T) { + entities := make(map[string]*types.RepeatedValue) + + entities["index_id"] = &types.RepeatedValue{ + Val: []*types.Value{ + {Val: &types.Value_Int64Val{Int64Val: 1}}, + {Val: &types.Value_Int64Val{Int64Val: 2}}, + {Val: &types.Value_Int64Val{Int64Val: 3}}, + }, + } + + featureNames := []string{"int_val", "int_val"} + + var featureNamesWithFeatureView []string + + for _, featureName := range featureNames { + featureNamesWithFeatureView = append(featureNamesWithFeatureView, "all_dtypes_sorted:"+featureName) + } + + request := &serving.GetOnlineFeaturesRangeRequest{ + Kind: &serving.GetOnlineFeaturesRangeRequest_Features{ + Features: &serving.FeatureList{ + Val: featureNamesWithFeatureView, + }, + }, + Entities: entities, + SortKeyFilters: []*serving.SortKeyFilter{ + { + SortKeyName: "event_timestamp", + Query: &serving.SortKeyFilter_Range{ + Range: &serving.SortKeyFilter_RangeQuery{ + RangeStart: &types.Value{Val: &types.Value_UnixTimestampVal{UnixTimestampVal: 0}}, + }, + }, + }, + }, + Limit: 10, + } + response, err := client.GetOnlineFeaturesRange(ctx, request) + assert.NoError(t, err) + assertResponseData(t, response, featureNames) +} + func TestGetOnlineFeaturesRange_withEmptySortKeyFilter(t *testing.T) { entities := make(map[string]*types.RepeatedValue) diff --git a/go/internal/feast/onlineserving/serving.go b/go/internal/feast/onlineserving/serving.go index fce73d28138..2dc3edee38a 100644 --- a/go/internal/feast/onlineserving/serving.go +++ b/go/internal/feast/onlineserving/serving.go @@ -815,18 +815,24 @@ func getEventTimestamp(timestamps []timestamp.Timestamp, index int) *timestamppb return ×tamppb.Timestamp{} } -func KeepOnlyRequestedFeatures( - vectors []*FeatureVector, +func KeepOnlyRequestedFeatures[T any]( + vectors []T, requestedFeatureRefs []string, featureService *model.FeatureService, - fullFeatureNames bool) ([]*FeatureVector, error) { - vectorsByName := make(map[string]*FeatureVector) - expectedVectors := make([]*FeatureVector, 0) + fullFeatureNames bool) ([]T, error) { + vectorsByName := make(map[string]T) + expectedVectors := make([]T, 0) usedVectors := make(map[string]bool) for _, vector := range vectors { - vectorsByName[vector.Name] = vector + if featureVector, ok := any(vector).(*FeatureVector); ok { + vectorsByName[featureVector.Name] = vector + } else if rangeFeatureVector, ok := any(vector).(*RangeFeatureVector); ok { + vectorsByName[rangeFeatureVector.Name] = vector + } else { + return nil, fmt.Errorf("unsupported vector type: %T", vector) + } } if featureService != nil { @@ -853,8 +859,16 @@ func KeepOnlyRequestedFeatures( // Free arrow arrays for vectors that were not used. for _, vector := range vectors { - if _, ok := usedVectors[vector.Name]; !ok { - vector.Values.Release() + if featureVector, ok := any(vector).(*FeatureVector); ok { + if _, ok := usedVectors[featureVector.Name]; !ok { + featureVector.Values.Release() + } + } else if rangeFeatureVector, ok := any(vector).(*RangeFeatureVector); ok { + if _, ok := usedVectors[rangeFeatureVector.Name]; !ok { + rangeFeatureVector.RangeValues.Release() + } + } else { + return nil, fmt.Errorf("unsupported vector type: %T", vector) } }