Skip to content

Commit

Permalink
added ckks support and tests at the service layer
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristianMct committed Apr 5, 2024
1 parent e47c18c commit e4450c5
Show file tree
Hide file tree
Showing 6 changed files with 405 additions and 177 deletions.
5 changes: 4 additions & 1 deletion circuits/circuits.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ type Runtime interface {
// - can omit the session-id part as it wil be automatically resolved by the runtime.
NewOperand(OperandLabel) *Operand

EvalLocal(needRlk bool, galKeys []uint64, f func(he.Evaluator) error) error // TODO Eval once freed // TODO NEXT: node allocates evaluators and pass them here.
// EvalLocal is used to perform local operation on the ciphertext. This is where the FHE computation
// is performed. The user must specify the required evaluation keys needed by the function. The provided
// function must not call any other Runtime function (ie., it must be strictly local circuit).
EvalLocal(needRlk bool, galKeys []uint64, f func(he.Evaluator) error) error

// DEC performes the decryption of in, with private output to rec.
// The decrypted operand is considered an output for the this circuit and the
Expand Down
45 changes: 43 additions & 2 deletions circuits/test_circuits.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ package circuits
import (
"github.com/tuneinsight/lattigo/v5/he"
"github.com/tuneinsight/lattigo/v5/schemes/bgv"
"github.com/tuneinsight/lattigo/v5/schemes/ckks"
)

// TestCircuits contains a set of test circuits for the helium framework.
var TestCircuits map[Name]Circuit = map[Name]Circuit{
"add-2-dec": func(ec Runtime) error {
"bgv-add-2-dec": func(ec Runtime) error {

params := ec.Parameters().(bgv.Parameters)

Expand All @@ -24,7 +25,7 @@ var TestCircuits map[Name]Circuit = map[Name]Circuit{
})
},

"mul-2-dec": func(ec Runtime) error {
"bgv-mul-2-dec": func(ec Runtime) error {

params := ec.Parameters().(bgv.Parameters)

Expand All @@ -44,4 +45,44 @@ var TestCircuits map[Name]Circuit = map[Name]Circuit{
"smudging": "40.0",
})
},
"ckks-add-2-dec": func(ec Runtime) error {

params := ec.Parameters().(ckks.Parameters)

in1, in2 := ec.Input("//p1/in"), ec.Input("//p2/in")

opRes := ec.NewOperand("//eval/sum")
ec.EvalLocal(false, nil, func(eval he.Evaluator) error {
opRes.Ciphertext = ckks.NewCiphertext(params, 1, params.MaxLevel())
return eval.Add(in1.Get().Ciphertext, in2.Get().Ciphertext, opRes.Ciphertext)
})

return ec.DEC(*opRes, "rec", map[string]string{
"smudging": "40.0",
})
},

"ckks-mul-2-dec": func(ec Runtime) error {

params := ec.Parameters().(ckks.Parameters)

in1, in2 := ec.Input("//p1/in"), ec.Input("//p2/in")

opRes := ec.NewOperand("//eval/mul")

err := ec.EvalLocal(true, nil, func(eval he.Evaluator) error {
opRes.Ciphertext = ckks.NewCiphertext(params, 1, params.MaxLevel())
if err := eval.MulRelin(in1.Get().Ciphertext, in2.Get().Ciphertext, opRes.Ciphertext); err != nil {
return err
}
return nil
})
if err != nil {
return err
}

return ec.DEC(*opRes, "rec", map[string]string{
"smudging": "10.0",
})
},
}
4 changes: 2 additions & 2 deletions node/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ var testSetupDescription = setup.Description{
}

var testCircuits = []TestCircuitSig{
{Signature: circuits.Signature{Name: "add-2-dec", Args: nil}, ExpResult: 1},
{Signature: circuits.Signature{Name: "mul-2-dec", Args: nil}, ExpResult: 0},
{Signature: circuits.Signature{Name: "bgv-add-2-dec", Args: nil}, ExpResult: 1},
{Signature: circuits.Signature{Name: "bgv-mul-2-dec", Args: nil}, ExpResult: 0},
}

var testSettings = []testSetting{
Expand Down
4 changes: 2 additions & 2 deletions services/compute/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (se *evaluatorRuntime) Init(ctx context.Context, md circuits.Metadata) (err

se.CompleteMap = protocols.NewCompletedProt(maps.Values(md.KeySwitchOps))

se.eval, err = se.getEvaluatorForCircuit(se.sess.Params, md)
se.eval, err = se.getEvaluatorForCircuit(se.sess.Params, md) // TODO pooled evaluators ?
if err != nil {
se.Logf("failed to get evaluator: %v", err)
}
Expand All @@ -84,7 +84,7 @@ func (se *evaluatorRuntime) Init(ctx context.Context, md circuits.Metadata) (err
func (se *evaluatorRuntime) getEvaluatorForCircuit(params session.FHEParameters, md circuits.Metadata) (eval he.Evaluator, err error) {

var rlk *rlwe.RelinearizationKey
if md.NeedRlk { // TODO NEXT: this is not populated without circuit parsing. Compute service could have a keyset computed from the setup description.
if md.NeedRlk {
rlk, err = se.pkProvider.GetRelinearizationKey(se.ctx)
if err != nil {
return nil, err
Expand Down
48 changes: 31 additions & 17 deletions services/compute/participant.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package compute
import (
"context"
"fmt"
"math/big"

"golang.org/x/exp/maps"

Expand All @@ -14,6 +15,7 @@ import (
"github.com/tuneinsight/lattigo/v5/he"
"github.com/tuneinsight/lattigo/v5/schemes/bgv"
"github.com/tuneinsight/lattigo/v5/schemes/ckks"
"github.com/tuneinsight/lattigo/v5/utils/bignum"
)

// participantRuntime is a runtime for a participant (a non-evaluator node) in a computation.
Expand Down Expand Up @@ -125,6 +127,30 @@ func (p *participantRuntime) CompletedProtocol(pd protocols.Descriptor) error {

// Circuit Interface

func isValidPlaintext(in interface{}) bool {
return isValidBGVPlaintextType(in) || isValidCKKSPlaintextType(in)
}

func isValidBGVPlaintextType(in interface{}) bool {
switch in.(type) {
case []uint64, []int64:
return true
default:
return false

}
}

func isValidCKKSPlaintextType(in interface{}) bool {
switch in.(type) {
case []complex128, []*bignum.Complex, []float64, []*big.Float:
return true
default:
return false

}
}

// Input reads an input operand with the given label from the context.
func (p *participantRuntime) Input(opl circuits.OperandLabel) *circuits.FutureOperand {

Expand All @@ -139,32 +165,20 @@ func (p *participantRuntime) Input(opl circuits.OperandLabel) *circuits.FutureOp
panic(fmt.Errorf("could not get inputs from input provider: %w", err)) // TODO return error
}

isValidPlaintextType := func(in interface{}) bool {
switch in.(type) {
case []uint64, []int64:
return true
default:
return false

}
}

var inct helium.Ciphertext
switch {
case isValidPlaintextType(in):
case isValidPlaintext(in):
var inpt *rlwe.Plaintext
switch enc := p.Encoder.(type) { // TODO: lattigo should have a generic Encode interface
switch enc := p.Encoder.(type) {
case *bgv.Encoder:
inpt = bgv.NewPlaintext(p.sess.Params.(bgv.Parameters), p.sess.Params.GetRLWEParameters().MaxLevel())
err = enc.Encode(in, inpt)
case *ckks.Encoder:
inpt = ckks.NewPlaintext(p.sess.Params.(ckks.Parameters), p.sess.Params.GetRLWEParameters().MaxLevel())
err = enc.Encode(in, inpt)
default:
err = fmt.Errorf("invalid encoder type %T", enc)
err = p.Encoder.(*ckks.Encoder).Encode(in, inpt)
}
if err != nil {
panic(err)
panic(fmt.Errorf("cannot encode input: %w", err))
}
in = inpt
fallthrough
Expand All @@ -182,7 +196,7 @@ func (p *participantRuntime) Input(opl circuits.OperandLabel) *circuits.FutureOp
CiphertextMetadata: helium.CiphertextMetadata{ID: helium.CiphertextID(opl)},
}
default:
panic(fmt.Errorf("invalid input type %T, should be either *rlwe.Plaintext or *rlwe.Ciphertext", in))
panic(fmt.Errorf("invalid input type %T for session parameters of type %T", in, p.sess.Parameters))
}

err = p.trans.PutCiphertext(p.ctx, inct)
Expand Down
Loading

0 comments on commit e4450c5

Please sign in to comment.