diff --git a/Makefile b/Makefile index 5dbd001..754f929 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ SHELL = /bin/sh -VERSION=1.1.2 +VERSION=1.1.3 BUILD=`git rev-parse HEAD` LDFLAGS=-ldflags "-w -s \ diff --git a/fakeserver/routes.go b/fakeserver/routes.go index 1150ab8..4ad43ce 100644 --- a/fakeserver/routes.go +++ b/fakeserver/routes.go @@ -26,6 +26,11 @@ type v1Assignment struct { Unsynced bool `json:"unsynced"` } +// v2AssignmentOverrideRequestBody is the JSON input for the V2 assignment override endpoint +type v2AssignmentOverrideRequestBody struct { + Assignments []v1Assignment `json:"assignments"` +} + // v1VisitorConfig is the JSON output type for V1 visitor_config endpoints type v1VisitorConfig struct { Splits map[string]*splits.Weights `json:"splits"` @@ -112,6 +117,10 @@ func (s *server) routes() { "/api/v1/assignment_override", postV1AssignmentOverride, ) + s.handlePostReturnNoContent( + "/api/v2/visitors/{v}/assignment_overrides", + postV2AssignmentOverride, + ) s.handleGet( "/api/v1/apps/{a}/versions/{v}/builds/{b}/visitors/{id}/config", getV1AppVisitorConfig, @@ -257,6 +266,35 @@ func postV1AssignmentOverride(r *http.Request) error { return nil } +func postV2AssignmentOverride(r *http.Request) error { + var assignments []v1Assignment + contentType := r.Header.Get("content-type") + switch { + case strings.HasPrefix(contentType, "application/json"): + requestBytes, err := ioutil.ReadAll(r.Body) + if err != nil { + return err + } + var assignmentBody v2AssignmentOverrideRequestBody + err = json.Unmarshal(requestBytes, &assignmentBody) + if err != nil { + return err + } + assignments = assignmentBody.Assignments + default: + return fmt.Errorf("got unexpected content type %s", contentType) + } + storedAssignments, err := fakeassignments.Read() + for _, assignment := range assignments { + (*storedAssignments)[assignment.SplitName] = assignment.Variant + } + err = fakeassignments.Write(storedAssignments) + if err != nil { + return err + } + return nil +} + func getV1AppVisitorConfig() (interface{}, error) { isplitRegistry, err := getV1SplitRegistry() splitRegistry := isplitRegistry.(map[string]*splits.Weights) diff --git a/fakeserver/server_test.go b/fakeserver/server_test.go index 3d35237..bc72dcb 100644 --- a/fakeserver/server_test.go +++ b/fakeserver/server_test.go @@ -1,6 +1,7 @@ package fakeserver import ( + "bytes" "io/ioutil" "log" "net/http" @@ -27,6 +28,10 @@ splits: weights: control: 60 treatment: 40 +- name: test.test2_experiment + weights: + control: 60 + treatment: 40 ` func TestMain(m *testing.M) { @@ -150,3 +155,40 @@ func TestPersistAssignment(t *testing.T) { require.Equal(t, "control", (*assignments)["test.test_experiment"]) }) } + +func TestPersistAssignmentV2(t *testing.T) { + os.Remove("testdata/assignments.yml") + + t.Run("it persists assignments to yaml", func(t *testing.T) { + w := httptest.NewRecorder() + h := createHandler() + + overrides := v2AssignmentOverrideRequestBody{ + Assignments: []v1Assignment{ + v1Assignment{ + SplitName: "test.test_experiment", + Variant: "control", + }, + v1Assignment{ + SplitName: "test.test2_experiment", + Variant: "treatment", + }, + }, + } + data, err := json.Marshal(overrides) + require.Nil(t, err) + + request := httptest.NewRequest("POST", "/api/v2/visitors/1/assignment_overrides", bytes.NewReader(data)) + request.Header.Add("Content-Type", "application/json") + request.Header.Add("Content-Length", strconv.Itoa(len(data))) + + h.ServeHTTP(w, request) + + require.Equal(t, http.StatusNoContent, w.Code) + + assignments, err := fakeassignments.Read() + require.Nil(t, err) + require.Equal(t, "control", (*assignments)["test.test_experiment"]) + require.Equal(t, "treatment", (*assignments)["test.test2_experiment"]) + }) +}