From e4450c5a451d68745ce0b7ed9de50fcc7ece7177 Mon Sep 17 00:00:00 2001 From: Christian Mouchet Date: Fri, 5 Apr 2024 11:41:10 +0200 Subject: [PATCH] added ckks support and tests at the service layer --- circuits/circuits.go | 5 +- circuits/test_circuits.go | 45 ++- node/node_test.go | 4 +- services/compute/evaluator.go | 4 +- services/compute/participant.go | 48 ++-- services/compute/service_test.go | 476 +++++++++++++++++++++---------- 6 files changed, 405 insertions(+), 177 deletions(-) diff --git a/circuits/circuits.go b/circuits/circuits.go index f76d428..00097b6 100644 --- a/circuits/circuits.go +++ b/circuits/circuits.go @@ -40,7 +40,10 @@ type Runtime interface { // - can omit the session-id part as it wil be automatically resolved by the runtime. NewOperand(OperandLabel) *Operand - EvalLocal(needRlk bool, galKeys []uint64, f func(he.Evaluator) error) error // TODO Eval once freed // TODO NEXT: node allocates evaluators and pass them here. + // EvalLocal is used to perform local operation on the ciphertext. This is where the FHE computation + // is performed. The user must specify the required evaluation keys needed by the function. The provided + // function must not call any other Runtime function (ie., it must be strictly local circuit). + EvalLocal(needRlk bool, galKeys []uint64, f func(he.Evaluator) error) error // DEC performes the decryption of in, with private output to rec. // The decrypted operand is considered an output for the this circuit and the diff --git a/circuits/test_circuits.go b/circuits/test_circuits.go index 2c5ba5e..1c5810f 100644 --- a/circuits/test_circuits.go +++ b/circuits/test_circuits.go @@ -3,11 +3,12 @@ package circuits import ( "github.com/tuneinsight/lattigo/v5/he" "github.com/tuneinsight/lattigo/v5/schemes/bgv" + "github.com/tuneinsight/lattigo/v5/schemes/ckks" ) // TestCircuits contains a set of test circuits for the helium framework. var TestCircuits map[Name]Circuit = map[Name]Circuit{ - "add-2-dec": func(ec Runtime) error { + "bgv-add-2-dec": func(ec Runtime) error { params := ec.Parameters().(bgv.Parameters) @@ -24,7 +25,7 @@ var TestCircuits map[Name]Circuit = map[Name]Circuit{ }) }, - "mul-2-dec": func(ec Runtime) error { + "bgv-mul-2-dec": func(ec Runtime) error { params := ec.Parameters().(bgv.Parameters) @@ -44,4 +45,44 @@ var TestCircuits map[Name]Circuit = map[Name]Circuit{ "smudging": "40.0", }) }, + "ckks-add-2-dec": func(ec Runtime) error { + + params := ec.Parameters().(ckks.Parameters) + + in1, in2 := ec.Input("//p1/in"), ec.Input("//p2/in") + + opRes := ec.NewOperand("//eval/sum") + ec.EvalLocal(false, nil, func(eval he.Evaluator) error { + opRes.Ciphertext = ckks.NewCiphertext(params, 1, params.MaxLevel()) + return eval.Add(in1.Get().Ciphertext, in2.Get().Ciphertext, opRes.Ciphertext) + }) + + return ec.DEC(*opRes, "rec", map[string]string{ + "smudging": "40.0", + }) + }, + + "ckks-mul-2-dec": func(ec Runtime) error { + + params := ec.Parameters().(ckks.Parameters) + + in1, in2 := ec.Input("//p1/in"), ec.Input("//p2/in") + + opRes := ec.NewOperand("//eval/mul") + + err := ec.EvalLocal(true, nil, func(eval he.Evaluator) error { + opRes.Ciphertext = ckks.NewCiphertext(params, 1, params.MaxLevel()) + if err := eval.MulRelin(in1.Get().Ciphertext, in2.Get().Ciphertext, opRes.Ciphertext); err != nil { + return err + } + return nil + }) + if err != nil { + return err + } + + return ec.DEC(*opRes, "rec", map[string]string{ + "smudging": "10.0", + }) + }, } diff --git a/node/node_test.go b/node/node_test.go index e923d0c..0e1c9bc 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -39,8 +39,8 @@ var testSetupDescription = setup.Description{ } var testCircuits = []TestCircuitSig{ - {Signature: circuits.Signature{Name: "add-2-dec", Args: nil}, ExpResult: 1}, - {Signature: circuits.Signature{Name: "mul-2-dec", Args: nil}, ExpResult: 0}, + {Signature: circuits.Signature{Name: "bgv-add-2-dec", Args: nil}, ExpResult: 1}, + {Signature: circuits.Signature{Name: "bgv-mul-2-dec", Args: nil}, ExpResult: 0}, } var testSettings = []testSetting{ diff --git a/services/compute/evaluator.go b/services/compute/evaluator.go index 78bf3b9..3e3a168 100644 --- a/services/compute/evaluator.go +++ b/services/compute/evaluator.go @@ -74,7 +74,7 @@ func (se *evaluatorRuntime) Init(ctx context.Context, md circuits.Metadata) (err se.CompleteMap = protocols.NewCompletedProt(maps.Values(md.KeySwitchOps)) - se.eval, err = se.getEvaluatorForCircuit(se.sess.Params, md) + se.eval, err = se.getEvaluatorForCircuit(se.sess.Params, md) // TODO pooled evaluators ? if err != nil { se.Logf("failed to get evaluator: %v", err) } @@ -84,7 +84,7 @@ func (se *evaluatorRuntime) Init(ctx context.Context, md circuits.Metadata) (err func (se *evaluatorRuntime) getEvaluatorForCircuit(params session.FHEParameters, md circuits.Metadata) (eval he.Evaluator, err error) { var rlk *rlwe.RelinearizationKey - if md.NeedRlk { // TODO NEXT: this is not populated without circuit parsing. Compute service could have a keyset computed from the setup description. + if md.NeedRlk { rlk, err = se.pkProvider.GetRelinearizationKey(se.ctx) if err != nil { return nil, err diff --git a/services/compute/participant.go b/services/compute/participant.go index 209e694..c68b035 100644 --- a/services/compute/participant.go +++ b/services/compute/participant.go @@ -3,6 +3,7 @@ package compute import ( "context" "fmt" + "math/big" "golang.org/x/exp/maps" @@ -14,6 +15,7 @@ import ( "github.com/tuneinsight/lattigo/v5/he" "github.com/tuneinsight/lattigo/v5/schemes/bgv" "github.com/tuneinsight/lattigo/v5/schemes/ckks" + "github.com/tuneinsight/lattigo/v5/utils/bignum" ) // participantRuntime is a runtime for a participant (a non-evaluator node) in a computation. @@ -125,6 +127,30 @@ func (p *participantRuntime) CompletedProtocol(pd protocols.Descriptor) error { // Circuit Interface +func isValidPlaintext(in interface{}) bool { + return isValidBGVPlaintextType(in) || isValidCKKSPlaintextType(in) +} + +func isValidBGVPlaintextType(in interface{}) bool { + switch in.(type) { + case []uint64, []int64: + return true + default: + return false + + } +} + +func isValidCKKSPlaintextType(in interface{}) bool { + switch in.(type) { + case []complex128, []*bignum.Complex, []float64, []*big.Float: + return true + default: + return false + + } +} + // Input reads an input operand with the given label from the context. func (p *participantRuntime) Input(opl circuits.OperandLabel) *circuits.FutureOperand { @@ -139,32 +165,20 @@ func (p *participantRuntime) Input(opl circuits.OperandLabel) *circuits.FutureOp panic(fmt.Errorf("could not get inputs from input provider: %w", err)) // TODO return error } - isValidPlaintextType := func(in interface{}) bool { - switch in.(type) { - case []uint64, []int64: - return true - default: - return false - - } - } - var inct helium.Ciphertext switch { - case isValidPlaintextType(in): + case isValidPlaintext(in): var inpt *rlwe.Plaintext - switch enc := p.Encoder.(type) { // TODO: lattigo should have a generic Encode interface + switch enc := p.Encoder.(type) { case *bgv.Encoder: inpt = bgv.NewPlaintext(p.sess.Params.(bgv.Parameters), p.sess.Params.GetRLWEParameters().MaxLevel()) err = enc.Encode(in, inpt) case *ckks.Encoder: inpt = ckks.NewPlaintext(p.sess.Params.(ckks.Parameters), p.sess.Params.GetRLWEParameters().MaxLevel()) - err = enc.Encode(in, inpt) - default: - err = fmt.Errorf("invalid encoder type %T", enc) + err = p.Encoder.(*ckks.Encoder).Encode(in, inpt) } if err != nil { - panic(err) + panic(fmt.Errorf("cannot encode input: %w", err)) } in = inpt fallthrough @@ -182,7 +196,7 @@ func (p *participantRuntime) Input(opl circuits.OperandLabel) *circuits.FutureOp CiphertextMetadata: helium.CiphertextMetadata{ID: helium.CiphertextID(opl)}, } default: - panic(fmt.Errorf("invalid input type %T, should be either *rlwe.Plaintext or *rlwe.Ciphertext", in)) + panic(fmt.Errorf("invalid input type %T for session parameters of type %T", in, p.sess.Parameters)) } err = p.trans.PutCiphertext(p.ctx, inct) diff --git a/services/compute/service_test.go b/services/compute/service_test.go index 678c718..614cb2f 100644 --- a/services/compute/service_test.go +++ b/services/compute/service_test.go @@ -17,53 +17,30 @@ import ( "github.com/stretchr/testify/require" "github.com/tuneinsight/lattigo/v5/core/rlwe" "github.com/tuneinsight/lattigo/v5/schemes/bgv" + "github.com/tuneinsight/lattigo/v5/schemes/ckks" "golang.org/x/sync/errgroup" ) -var TestPN12QP109 = bgv.ParametersLiteral{ - LogN: 12, - Q: []uint64{0x7ffffffec001, 0x400000008001}, // 47 + 46 bits - P: []uint64{0xa001}, // 15 bits - PlaintextModulus: 65537, -} - -var rangeParam = []bgv.ParametersLiteral{TestPN12QP109 /* rlwe.TestPN13QP218 , rlwe.TestPN14QP438, rlwe.TestPN15QP880*/} - type TestCircuitSig struct { circuits.Signature - ExpResult uint64 + ExpResult interface{} } type testSetting struct { - N int // N - total parties - T int // T - parties in the access structure - CircuitSigs []TestCircuitSig - Reciever helium.NodeID - Rep int // numer of repetition for each circuit + N int // N - total parties + T int // T - parties in the access structure + Reciever helium.NodeID + Rep int // numer of repetition for each circuit } var testNodeMapping = map[string]helium.NodeID{"p1": "node-0", "p2": "node-1", "eval": "helper"} -var testCircuitSigs = []TestCircuitSig{ - {Signature: circuits.Signature{Name: "add-2-dec", Args: nil}, ExpResult: 1}, - {Signature: circuits.Signature{Name: "mul-2-dec", Args: nil}, ExpResult: 0}, -} - -func NodeIDtoTestInput(nid string) []uint64 { - num := strings.Trim(string(nid), "node-") - i, err := strconv.ParseUint(num, 10, 64) - if err != nil { - panic(err) - } - return []uint64{i} -} - var testSettings = []testSetting{ - {N: 2, CircuitSigs: testCircuitSigs, Reciever: "node-0"}, - {N: 2, CircuitSigs: testCircuitSigs, Reciever: "helper"}, - {N: 3, T: 2, CircuitSigs: testCircuitSigs, Reciever: "node-0"}, - {N: 3, T: 2, CircuitSigs: testCircuitSigs, Reciever: "helper"}, - {N: 3, T: 2, CircuitSigs: testCircuitSigs, Reciever: "helper", Rep: 10}, + {N: 2, Reciever: "node-0"}, + {N: 2, Reciever: "helper"}, + {N: 3, T: 2, Reciever: "node-0"}, + {N: 3, T: 2, Reciever: "helper"}, + {N: 3, T: 2, Reciever: "helper", Rep: 10}, } type testnode struct { @@ -89,153 +66,346 @@ func (tnt *testNodeTrans) GetCiphertext(ctx context.Context, ctID helium.Ciphert return tnt.helperSrv.GetCiphertext(ctx, ctID) } -func TestCloudAssistedCompute(t *testing.T) { - for _, literalParams := range rangeParam { - for _, ts := range testSettings { - if ts.T == 0 { - ts.T = ts.N - } +func TestCloudAssistedComputeBGV(t *testing.T) { - if ts.Rep == 0 { - ts.Rep = 1 - } + bgvParamsLiteral := bgv.ParametersLiteral{ + LogN: 12, + Q: []uint64{0x7ffffffec001, 0x400000008001}, // 47 + 46 bits + P: []uint64{0xa001}, // 15 bits + PlaintextModulus: 65537, + } - t.Run(fmt.Sprintf("NParty=%d/T=%d/rec=%s/rep=%d", ts.N, ts.T, ts.Reciever, ts.Rep), func(t *testing.T) { + var testCircuitSigs = []TestCircuitSig{ + {Signature: circuits.Signature{Name: "bgv-add-2-dec", Args: nil}, ExpResult: uint64(3)}, + {Signature: circuits.Signature{Name: "bgv-mul-2-dec", Args: nil}, ExpResult: uint64(2)}, + } - hid := helium.NodeID("helper") + nodeIDtoTestInput := func(nid string) []uint64 { + num := strings.Trim(string(nid), "node-") + i, err := strconv.ParseUint(num, 10, 64) + if err != nil { + panic(err) + } + return []uint64{i + 1} + } - testSess, err := session.NewTestSession(ts.N, ts.T, literalParams, hid) - if err != nil { - t.Fatal(err) - } - sessParams := testSess.SessParams + for _, ts := range testSettings { + if ts.T == 0 { + ts.T = ts.N + } - ctx := helium.NewBackgroundContext(sessParams.ID) + if ts.Rep == 0 { + ts.Rep = 1 + } - nids := utils.NewSet(sessParams.Nodes) + t.Run(fmt.Sprintf("NParty=%d/T=%d/rec=%s/rep=%d", ts.N, ts.T, ts.Reciever, ts.Rep), func(t *testing.T) { - coord := coordinator.NewTestCoordinator() - protoTrans := protocols.NewTestTransport() + hid := helium.NodeID("helper") - all := make(map[helium.NodeID]*testnode, ts.N+1) - clou := new(testnode) - all["helper"] = clou + testSess, err := session.NewTestSession(ts.N, ts.T, bgvParamsLiteral, hid) + if err != nil { + t.Fatal(err) + } + sessParams := testSess.SessParams - conf := ServiceConfig{ - CircQueueSize: 300, - MaxCircuitEvaluation: 5, - Protocols: protocols.ExecutorConfig{ - SigQueueSize: 300, - MaxProtoPerNode: 1, - MaxAggregation: 1, - MaxParticipation: 1, - }, - } + ctx := helium.NewBackgroundContext(sessParams.ID) + + nids := utils.NewSet(sessParams.Nodes) - srvTrans := &testNodeTrans{Transport: protoTrans} - clou.Coordinator = coord - clou.Service, err = NewComputeService(hid, testSess.HelperSession, conf, testSess, srvTrans) + coord := coordinator.NewTestCoordinator() + protoTrans := protocols.NewTestTransport() + + all := make(map[helium.NodeID]*testnode, ts.N+1) + clou := new(testnode) + all["helper"] = clou + + conf := ServiceConfig{ + CircQueueSize: 300, + MaxCircuitEvaluation: 5, + Protocols: protocols.ExecutorConfig{ + SigQueueSize: 300, + MaxProtoPerNode: 1, + MaxAggregation: 1, + MaxParticipation: 1, + }, + } + + srvTrans := &testNodeTrans{Transport: protoTrans} + clou.Coordinator = coord + clou.Service, err = NewComputeService(hid, testSess.HelperSession, conf, testSess, srvTrans) + if err != nil { + t.Fatal(err) + } + clou.OutputReceiver = make(chan circuits.Output) + clou.Outputs = make(map[helium.CircuitID]circuits.Output) + + clients := make(map[helium.NodeID]*testnode, ts.N) + for nid := range nids { + nid := nid + cli := new(testnode) + cli.Session = testSess.NodeSessions[nid] + srvTrans := &testNodeTrans{Transport: protoTrans.TransportFor(nid), helperSrv: clou.Service} + cli.Service, err = NewComputeService(nid, testSess.NodeSessions[nid], conf, testSess, srvTrans) if err != nil { t.Fatal(err) } - clou.OutputReceiver = make(chan circuits.Output) - clou.Outputs = make(map[helium.CircuitID]circuits.Output) - - clients := make(map[helium.NodeID]*testnode, ts.N) - for nid := range nids { - nid := nid - cli := new(testnode) - cli.Session = testSess.NodeSessions[nid] - srvTrans := &testNodeTrans{Transport: protoTrans.TransportFor(nid), helperSrv: clou.Service} - cli.Service, err = NewComputeService(nid, testSess.NodeSessions[nid], conf, testSess, srvTrans) - if err != nil { - t.Fatal(err) - } - cli.Coordinator = coord.NewPeerCoordinator(nid) + cli.Coordinator = coord.NewPeerCoordinator(nid) - require.Nil(t, err) - cli.InputProvider = func(ctx context.Context, _ helium.CircuitID, ol circuits.OperandLabel, _ session.Session) (any, error) { - return NodeIDtoTestInput(string(nid)), nil - } - cli.OutputReceiver = make(chan circuits.Output) - cli.Outputs = make(map[helium.CircuitID]circuits.Output) - clou.Executor.Register(nid) - clients[nid] = cli - all[nid] = cli + require.Nil(t, err) + cli.InputProvider = func(ctx context.Context, _ helium.CircuitID, ol circuits.OperandLabel, _ session.Session) (any, error) { + return nodeIDtoTestInput(string(nid)), nil } + cli.OutputReceiver = make(chan circuits.Output) + cli.Outputs = make(map[helium.CircuitID]circuits.Output) + clou.Executor.Register(nid) + clients[nid] = cli + all[nid] = cli + } - for _, n := range all { - err = n.RegisterCircuits(circuits.TestCircuits) - require.Nil(t, err) - } + for _, n := range all { + err = n.RegisterCircuits(circuits.TestCircuits) + require.Nil(t, err) + } - g, ctx := errgroup.WithContext(ctx) - // run the nodes - for nid, node := range all { - nid := nid - cli := node - g.Go(func() error { - for out := range cli.OutputReceiver { - cli.Outputs[out.CircuitID] = out - } - return nil - }) - g.Go(func() error { - return errors.WithMessage( - cli.Service.Run(ctx, cli.InputProvider, cli.OutputReceiver, cli.Coordinator), - fmt.Sprintf("at node %s", nid)) - }) - } + g, ctx := errgroup.WithContext(ctx) + // run the nodes + for nid, node := range all { + nid := nid + cli := node + g.Go(func() error { + for out := range cli.OutputReceiver { + cli.Outputs[out.CircuitID] = out + } + return nil + }) + g.Go(func() error { + return errors.WithMessage( + cli.Service.Run(ctx, cli.InputProvider, cli.OutputReceiver, cli.Coordinator), + fmt.Sprintf("at node %s", nid)) + }) + } - cds := make([]circuits.Descriptor, 0, len(ts.CircuitSigs)*ts.Rep) - expResult := make(map[helium.CircuitID]uint64) - for _, tc := range ts.CircuitSigs { - for r := 0; r < ts.Rep; r++ { - cid := helium.CircuitID(fmt.Sprintf("%s-%d", tc.Name, r)) - cd := circuits.Descriptor{ - Signature: tc.Signature, - CircuitID: cid, - NodeMapping: testNodeMapping, - Evaluator: "helper", - } - cd.NodeMapping["rec"] = ts.Reciever - cds = append(cds, cd) - expResult[cid] = tc.ExpResult + cds := make([]circuits.Descriptor, 0, len(testCircuitSigs)*ts.Rep) + expResult := make(map[helium.CircuitID]uint64) + for _, tc := range testCircuitSigs { + for r := 0; r < ts.Rep; r++ { + cid := helium.CircuitID(fmt.Sprintf("%s-%d", tc.Name, r)) + cd := circuits.Descriptor{ + Signature: tc.Signature, + CircuitID: cid, + NodeMapping: testNodeMapping, + Evaluator: "helper", } + cd.NodeMapping["rec"] = ts.Reciever + cds = append(cds, cd) + expResult[cid] = tc.ExpResult.(uint64) } + } + + for _, cd := range cds { + coord.LogEvent(coordinator.Event{CircuitEvent: &circuits.Event{EventType: circuits.Started, Descriptor: cd}}) + } + coord.Close() + + err = g.Wait() // waits for all parties to terminate + require.Nil(t, err) + + bgvParams, err := bgv.NewParameters(testSess.RlweParams, bgvParamsLiteral.PlaintextModulus) + require.Nil(t, err) + encoder := bgv.NewEncoder(bgvParams) + + //fmt.Println("all done") + rec := all[ts.Reciever] + for cid, expRes := range expResult { + out, has := rec.Outputs[cid] + require.True(t, has, "reciever should have an output") + delete(rec.Outputs, cid) + pt := &rlwe.Plaintext{Element: out.Ciphertext.Element, Value: out.Ciphertext.Value[0]} + pt.IsNTT = out.Ciphertext.IsNTT + res := make([]uint64, bgvParams.MaxSlots()) + err := encoder.Decode(pt, res) + require.Nil(t, err) + //fmt.Println(out.OperandLabel, res[:10]) + require.Equal(t, expRes, res[0]) + } + + for nid, n := range all { + require.Empty(t, n.Outputs, "node %s should have no extra outputs", nid) + } - for _, cd := range cds { - coord.LogEvent(coordinator.Event{CircuitEvent: &circuits.Event{EventType: circuits.Started, Descriptor: cd}}) + }) + } +} + +func TestCloudAssistedComputeCKKS(t *testing.T) { + + ckksParamsLiteral := ckks.ParametersLiteral{ + LogN: 12, + Q: []uint64{0x7ffffffec001, 0x400000008001}, // 47 + 46 bits + P: []uint64{0xa001}, // 15 bits + LogDefaultScale: 32, + } + + var testCircuitSigs = []TestCircuitSig{ + {Signature: circuits.Signature{Name: "ckks-add-2-dec", Args: nil}, ExpResult: 1.0}, + {Signature: circuits.Signature{Name: "ckks-mul-2-dec", Args: nil}, ExpResult: 0.2222222222222222}, + } + + nodeIDtoTestInput := func(nid string) []float64 { + num := strings.Trim(string(nid), "node-") + i, err := strconv.ParseUint(num, 10, 64) + if err != nil { + panic(err) + } + return []float64{(float64(i) + 1.0) / 3} + } + + for _, ts := range testSettings { + if ts.T == 0 { + ts.T = ts.N + } + + if ts.Rep == 0 { + ts.Rep = 1 + } + + t.Run(fmt.Sprintf("NParty=%d/T=%d/rec=%s/rep=%d", ts.N, ts.T, ts.Reciever, ts.Rep), func(t *testing.T) { + + hid := helium.NodeID("helper") + + testSess, err := session.NewTestSession(ts.N, ts.T, ckksParamsLiteral, hid) + if err != nil { + t.Fatal(err) + } + sessParams := testSess.SessParams + + ctx := helium.NewBackgroundContext(sessParams.ID) + + nids := utils.NewSet(sessParams.Nodes) + + coord := coordinator.NewTestCoordinator() + protoTrans := protocols.NewTestTransport() + + all := make(map[helium.NodeID]*testnode, ts.N+1) + clou := new(testnode) + all["helper"] = clou + + conf := ServiceConfig{ + CircQueueSize: 300, + MaxCircuitEvaluation: 5, + Protocols: protocols.ExecutorConfig{ + SigQueueSize: 300, + MaxProtoPerNode: 1, + MaxAggregation: 1, + MaxParticipation: 1, + }, + } + + srvTrans := &testNodeTrans{Transport: protoTrans} + clou.Coordinator = coord + clou.Service, err = NewComputeService(hid, testSess.HelperSession, conf, testSess, srvTrans) + if err != nil { + t.Fatal(err) + } + clou.OutputReceiver = make(chan circuits.Output) + clou.Outputs = make(map[helium.CircuitID]circuits.Output) + + clients := make(map[helium.NodeID]*testnode, ts.N) + for nid := range nids { + nid := nid + cli := new(testnode) + cli.Session = testSess.NodeSessions[nid] + srvTrans := &testNodeTrans{Transport: protoTrans.TransportFor(nid), helperSrv: clou.Service} + cli.Service, err = NewComputeService(nid, testSess.NodeSessions[nid], conf, testSess, srvTrans) + if err != nil { + t.Fatal(err) } - coord.Close() + cli.Coordinator = coord.NewPeerCoordinator(nid) - err = g.Wait() // waits for all parties to terminate require.Nil(t, err) + cli.InputProvider = func(ctx context.Context, _ helium.CircuitID, ol circuits.OperandLabel, _ session.Session) (any, error) { + return nodeIDtoTestInput(string(nid)), nil + } + cli.OutputReceiver = make(chan circuits.Output) + cli.Outputs = make(map[helium.CircuitID]circuits.Output) + clou.Executor.Register(nid) + clients[nid] = cli + all[nid] = cli + } - bgvParams, err := bgv.NewParameters(testSess.RlweParams, literalParams.PlaintextModulus) + for _, n := range all { + err = n.RegisterCircuits(circuits.TestCircuits) require.Nil(t, err) - encoder := bgv.NewEncoder(bgvParams) - - fmt.Println("all done") - rec := all[ts.Reciever] - for cid, expRes := range expResult { - out, has := rec.Outputs[cid] - require.True(t, has, "reciever should have an output") - delete(rec.Outputs, cid) - pt := &rlwe.Plaintext{Element: out.Ciphertext.Element, Value: out.Ciphertext.Value[0]} - pt.IsNTT = true - res := make([]uint64, bgvParams.MaxSlots()) - err := encoder.Decode(pt, res) - require.Nil(t, err) - fmt.Println(out.OperandLabel, res[:10]) - require.Equal(t, expRes, res[0]) - } + } + + g, ctx := errgroup.WithContext(ctx) + // run the nodes + for nid, node := range all { + nid := nid + cli := node + g.Go(func() error { + for out := range cli.OutputReceiver { + cli.Outputs[out.CircuitID] = out + } + return nil + }) + g.Go(func() error { + return errors.WithMessage( + cli.Service.Run(ctx, cli.InputProvider, cli.OutputReceiver, cli.Coordinator), + fmt.Sprintf("at node %s", nid)) + }) + } - for nid, n := range all { - require.Empty(t, n.Outputs, "node %s should have no extra outputs", nid) + cds := make([]circuits.Descriptor, 0, len(testCircuitSigs)*ts.Rep) + expResult := make(map[helium.CircuitID]float64) + for _, tc := range testCircuitSigs { + for r := 0; r < ts.Rep; r++ { + cid := helium.CircuitID(fmt.Sprintf("%s-%d", tc.Name, r)) + cd := circuits.Descriptor{ + Signature: tc.Signature, + CircuitID: cid, + NodeMapping: testNodeMapping, + Evaluator: "helper", + } + cd.NodeMapping["rec"] = ts.Reciever + cds = append(cds, cd) + expResult[cid] = tc.ExpResult.(float64) } + } - }) - } + for _, cd := range cds { + coord.LogEvent(coordinator.Event{CircuitEvent: &circuits.Event{EventType: circuits.Started, Descriptor: cd}}) + } + coord.Close() + + err = g.Wait() // waits for all parties to terminate + require.Nil(t, err) + + ckksParams, err := ckks.NewParametersFromLiteral(ckksParamsLiteral) + require.Nil(t, err) + encoder := ckks.NewEncoder(ckksParams) + + //fmt.Println("all done") + rec := all[ts.Reciever] + for cid, expRes := range expResult { + out, has := rec.Outputs[cid] + require.True(t, has, "reciever should have an output") + delete(rec.Outputs, cid) + pt := &rlwe.Plaintext{Element: out.Ciphertext.Element, Value: out.Ciphertext.Value[0]} + pt.IsNTT = out.IsNTT + pt.Scale = out.Scale + + res := make([]float64, ckksParams.MaxSlots()) + err := encoder.Decode(pt, res) + require.Nil(t, err) + //fmt.Printf("%s: exp=%.4f res=%.4f\n", out.OperandLabel, expRes, res[0]) + require.InDelta(t, expRes, res[0], 0.0001) // TODO better bounds + } + + for nid, n := range all { + require.Empty(t, n.Outputs, "node %s should have no extra outputs", nid) + } + + }) } }