diff --git a/Makefile b/Makefile index a5171c9..24614ee 100644 --- a/Makefile +++ b/Makefile @@ -85,7 +85,7 @@ build-so: .PHONY: unit-test-local unit-test-local: - go test -tags envoy${ENVOY_API_VERSION} -v ${GO_MODULES} -gcflags="all=-N -l" -race -covermode=atomic -coverprofile=coverage.out -coverpkg=${PROJECT_NAME}/... + go test -bench=RLS -tags envoy${ENVOY_API_VERSION} -v ${GO_MODULES} -gcflags="all=-N -l" -race -covermode=atomic -coverprofile=coverage.out -coverpkg=${PROJECT_NAME}/... .PHONY: unit-test unit-test: @@ -135,7 +135,7 @@ build-test-so: integration-test: test -d /tmp/htnn_coverage && rm -rf /tmp/htnn_coverage || true if find ./tests/integration -name '*.go' | grep .go > /dev/null; then \ - PROXY_IMAGE=${PROXY_IMAGE} go test -tags integrationtest,envoy${ENVOY_API_VERSION},${EXTRA_GO_BUILD_TAGS} -count 1 -v ./tests/integration/...; \ + PROXY_IMAGE=${PROXY_IMAGE} go test -bench=RLS -tags integrationtest,envoy${ENVOY_API_VERSION},${EXTRA_GO_BUILD_TAGS} -count 1 -v ./tests/integration/...; \ fi # The host of metadata center service, it could be a domain or an IP. diff --git a/pkg/prediction/rls/rls.go b/pkg/prediction/rls/rls.go new file mode 100644 index 0000000..d17e9a3 --- /dev/null +++ b/pkg/prediction/rls/rls.go @@ -0,0 +1,141 @@ +// Copyright The AIGW Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rls + +const TPOT_COEFF_NUM = 2 + +// TpotRecursiveLeastSquares +type TpotRecursiveLeastSquares struct { + theta []float64 + P [][]float64 + forget float64 +} + +// NewTpotRLS create RLS instance +func NewTpotRLS(forget float64) *TpotRecursiveLeastSquares { + if forget <= 0 || forget > 1 { + forget = 1.0 + } + size := TPOT_COEFF_NUM + 1 + theta := make([]float64, size) + P := make([][]float64, size) + for i := 0; i < size; i++ { + P[i] = make([]float64, size) + P[i][i] = 1e6 + } + return &TpotRecursiveLeastSquares{ + theta: theta, + P: P, + forget: forget, + } +} + +func (r *TpotRecursiveLeastSquares) Update(x []uint64, y float64) { + if len(x) != TPOT_COEFF_NUM { + return + } + + // phi = [x0, x1, 1] + // since the input x is limited, it's safe to convert uint64 to float64 + phi0 := float64(x[0]) + phi1 := float64(x[1]) + phi2 := 1.0 + + P00 := r.P[0][0] + P01 := r.P[0][1] + P02 := r.P[0][2] + P10 := r.P[1][0] + P11 := r.P[1][1] + P12 := r.P[1][2] + P20 := r.P[2][0] + P21 := r.P[2][1] + P22 := r.P[2][2] + + // PHI = P * phi + PHI0 := P00*phi0 + P01*phi1 + P02*phi2 + PHI1 := P10*phi0 + P11*phi1 + P12*phi2 + PHI2 := P20*phi0 + P21*phi1 + P22*phi2 + + // den = forget + phiᵀ * PHI + den := r.forget + + phi0*PHI0 + + phi1*PHI1 + + phi2*PHI2 + invDen := 1.0 / den + + // K = PHI / den + K0 := PHI0 * invDen + K1 := PHI1 * invDen + K2 := PHI2 * invDen + + // yPred = phiᵀ * theta + yPred := phi0*r.theta[0] + phi1*r.theta[1] + phi2*r.theta[2] + e := y - yPred + + // optimize theta + r.theta[0] += K0 * e + r.theta[1] += K1 * e + r.theta[2] += K2 * e + + // update P, P = (P - K * PHIᵀ) / forget + finv := 1.0 / r.forget + + r.P[0][0] = (P00 - K0*PHI0) * finv + r.P[0][1] = (P01 - K0*PHI1) * finv + r.P[0][2] = (P02 - K0*PHI2) * finv + + r.P[1][0] = (P10 - K1*PHI0) * finv + r.P[1][1] = (P11 - K1*PHI1) * finv + r.P[1][2] = (P12 - K1*PHI2) * finv + + r.P[2][0] = (P20 - K2*PHI0) * finv + r.P[2][1] = (P21 - K2*PHI1) * finv + r.P[2][2] = (P22 - K2*PHI2) * finv +} + +// Predict +func (r *TpotRecursiveLeastSquares) Predict(x []uint64) float64 { + if len(x) != TPOT_COEFF_NUM { + return -1 + } + size := TPOT_COEFF_NUM + 1 + phi := make([]uint64, size) + copy(phi, x) + phi[size-1] = 1 + y := float64(phi[0])*r.theta[0] + float64(phi[1])*r.theta[1] + float64(phi[2])*r.theta[2] + return y +} + +// Params return the current coefficients [a1..an, c] +func (r *TpotRecursiveLeastSquares) Params() []float64 { + out := make([]float64, len(r.theta)) + copy(out, r.theta) + return out +} + +func (r *TpotRecursiveLeastSquares) Clone() *TpotRecursiveLeastSquares { + size := TPOT_COEFF_NUM + 1 + newRLS := &TpotRecursiveLeastSquares{ + theta: make([]float64, size), + P: make([][]float64, size), + forget: r.forget, + } + copy(newRLS.theta, r.theta) + for i := range r.P { + newRLS.P[i] = make([]float64, size) + copy(newRLS.P[i], r.P[i]) + } + return newRLS +} diff --git a/pkg/prediction/rls/rls_bench_test.go b/pkg/prediction/rls/rls_bench_test.go new file mode 100644 index 0000000..a59dd6f --- /dev/null +++ b/pkg/prediction/rls/rls_bench_test.go @@ -0,0 +1,66 @@ +// Copyright The AIGW Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rls + +import ( + "testing" +) + +// Prepare test data +func getTestData() (x1, x2 []uint64, y float64) { + x1 = []uint64{13636} + x2 = []uint64{7997} + y = 1762.002081 + + return x1, x2, y +} + +// Benchmark the Train method +func BenchmarkRLS_Train(b *testing.B) { + model := NewTpotRLS(1.0) + x1, x2, y := getTestData() + dataSize := len(x1) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + idx := i % dataSize + model.Update([]uint64{x1[idx], x2[idx]}, y) + } +} + +// Predefined model parameters +func getPredefinedParams() []float64 { + return []float64{8.128444e-06, -9.622197e-06, -1.476342e-07, 0.086033, -0.072182, 130.521289} +} + +// Benchmark the Predict method +func BenchmarkRLS_Predict(b *testing.B) { + model := NewTpotRLS(1.0) + model.theta = getPredefinedParams() + x1, x2, y := getTestData() + dataSize := len(x1) + + for i := 0; i < dataSize; i++ { + model.Update([]uint64{x1[i], x2[i]}, y) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + idx := i % dataSize + model.Predict([]uint64{x1[idx], x2[idx]}) + } +} diff --git a/pkg/prediction/rls/rls_test.go b/pkg/prediction/rls/rls_test.go new file mode 100644 index 0000000..254f4aa --- /dev/null +++ b/pkg/prediction/rls/rls_test.go @@ -0,0 +1,82 @@ +// Copyright The AIGW Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rls + +import ( + "testing" +) + +const ( + UT_RLS_PARAM_SIZE = 2 + UT_RLS_FORGET_RATIO = 1.0 +) + +// Test NewTpotRLS basic init +func TestNewTpotRLS(t *testing.T) { + r := NewTpotRLS(UT_RLS_FORGET_RATIO) + if r == nil { + t.Error("RLS instance should not be nil") + } + if len(r.Params()) != (UT_RLS_PARAM_SIZE + 1) { + t.Fatalf("expected params size %d, got %d", UT_RLS_PARAM_SIZE+1, len(r.Params())) + } +} + +// Test Update and Predict +func TestRLSUpdatePredict(t *testing.T) { + r := NewTpotRLS(UT_RLS_FORGET_RATIO) + + x := []uint64{2, 3} + y := 10.0 + + r.Update(x, y) + p := r.Predict(x) + + if p == 0 { + t.Fatalf("predict should not be zero after update; got %v", p) + } +} + +// Test Predict dim mismatch +func TestPredictDimMismatch(t *testing.T) { + r := NewTpotRLS(UT_RLS_FORGET_RATIO) + + out := r.Predict([]uint64{1}) // wrong dim + if out != -1 { + t.Fatalf("expected -1 on dim mismatch, got %v", out) + } +} + +// Test Update dim mismatch (should not panic) +func TestUpdateDimMismatch(t *testing.T) { + r := NewTpotRLS(UT_RLS_FORGET_RATIO) + + before := r.Params() + r.Update([]uint64{1}, 5.0) // wrong dim + after := r.Params() + for i := 0; i < len(before); i++ { + if before[i] != after[i] { + t.Fatal("coeff should not be updated") + } + } +} + +// Test Params returns copy +func TestParams(t *testing.T) { + r := NewTpotRLS(UT_RLS_FORGET_RATIO) + if len(r.Params()) != UT_RLS_PARAM_SIZE+1 { + t.Fatalf("coeff size invalid, expect %v actual %v", UT_RLS_PARAM_SIZE+1, len(r.Params())) + } +} diff --git a/pkg/prediction/tpot_prediction.go b/pkg/prediction/tpot_prediction.go new file mode 100644 index 0000000..4253ecf --- /dev/null +++ b/pkg/prediction/tpot_prediction.go @@ -0,0 +1,91 @@ +// Copyright The AIGW Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prediction + +import ( + "sort" + + rls "github.com/aigw-project/aigw/pkg/prediction/rls" +) + +type TpotPrediction interface { + // set threshes for tpot predictor and init rls for each thresh + Init(thresh []uint64) + Clone() TpotPrediction + // return all rls parameters + Params() map[int]map[string]float64 + + Train(batchsize, totalTokenNum uint64, y float64) + Predict(batchsize, totalTokenNum uint64) float64 +} + +// TpotPredictor +type TpotPredictor struct { + thresh []uint64 + rls []*rls.TpotRecursiveLeastSquares +} + +// Init +func (c *TpotPredictor) Init(thresh []uint64) { + c.thresh = make([]uint64, len(thresh)) + copy(c.thresh, thresh) + sort.Slice(c.thresh, func(i, j int) bool { return c.thresh[i] < c.thresh[j] }) + seg := len(c.thresh) + 1 + c.rls = make([]*rls.TpotRecursiveLeastSquares, seg) + for i := 0; i < seg; i++ { + c.rls[i] = rls.NewTpotRLS(1.0) + } +} + +// Params +func (c *TpotPredictor) Params() map[int]map[string]float64 { + m := make(map[int]map[string]float64) + for i, r := range c.rls { + p := r.Params() + m[i] = map[string]float64{"a": p[0], "b": p[1], "c": p[2]} + } + return m +} + +func (c *TpotPredictor) Clone() TpotPrediction { + newPred := &TpotPredictor{ + thresh: make([]uint64, len(c.thresh)), + rls: make([]*rls.TpotRecursiveLeastSquares, len(c.rls)), + } + for i := range newPred.rls { + newPred.rls[i] = c.rls[i].Clone() + } + return newPred +} + +// segment +func (c *TpotPredictor) segment(batchsize uint64) int { + idx := sort.Search(len(c.thresh), func(i int) bool { + return batchsize < c.thresh[i] + }) + return idx +} + +// Train +func (c *TpotPredictor) Train(batchsize, totalTokenNum uint64, y float64) { + seg := c.segment(batchsize) + c.rls[seg].Update([]uint64{batchsize, totalTokenNum}, y) +} + +// Predict +func (c *TpotPredictor) Predict(batchsize, totalTokenNum uint64) float64 { + seg := c.segment(batchsize) + return c.rls[seg].Predict([]uint64{batchsize, totalTokenNum}) +} diff --git a/pkg/prediction/tpot_prediction_test.go b/pkg/prediction/tpot_prediction_test.go new file mode 100644 index 0000000..1d65903 --- /dev/null +++ b/pkg/prediction/tpot_prediction_test.go @@ -0,0 +1,78 @@ +// Copyright The AIGW Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prediction + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTpotPredictorInit(t *testing.T) { + p := &TpotPredictor{} + p.Init([]uint64{10, 50, 100}) + + if !assert.Equal(t, p.Params()[0], map[string]float64{"a": 0, "b": 0, "c": 0}) { + t.Fatalf("RLS not initialized correctly") + } +} + +func TestSegment(t *testing.T) { + p := &TpotPredictor{} + p.Init([]uint64{10, 20, 30}) + + assert.Equal(t, 0, p.segment(5)) + assert.Equal(t, 1, p.segment(10)) + assert.Equal(t, 1, p.segment(15)) + assert.Equal(t, 2, p.segment(25)) + assert.Equal(t, 3, p.segment(100)) +} + +func TestTrainAndPredict(t *testing.T) { + p := &TpotPredictor{} + p.Init([]uint64{10}) + + var batchsize uint64 = 5 + var totalTokenNum uint64 = 100 + y := 50.0 + p.Train(batchsize, totalTokenNum, y) + + out := p.Predict(batchsize, totalTokenNum) + if out == 0 { + t.Fatalf("predict should produce non-zero after training; got %v", out) + } +} + +func TestParams(t *testing.T) { + p := &TpotPredictor{} + p.Init([]uint64{10}) + + params := p.Params() + if len(params) != 2 { + t.Fatalf("expected 2 segments, got %d", len(params)) + } + + for _, v := range params { + if _, ok := v["a"]; !ok { + t.Fatalf("missing a") + } + if _, ok := v["b"]; !ok { + t.Fatalf("missing b") + } + if _, ok := v["c"]; !ok { + t.Fatalf("missing c") + } + } +}