Skip to content

Commit

Permalink
moved entrypoint types and funcs in the main helium package, moved pr…
Browse files Browse the repository at this point in the history
…otobuf-related symbols to api
  • Loading branch information
ChristianMct committed Apr 17, 2024
1 parent 2d7c77c commit 47f2b6e
Show file tree
Hide file tree
Showing 13 changed files with 113 additions and 112 deletions.
File renamed without changes.
60 changes: 30 additions & 30 deletions transport/centralized/messages.go → api/messages.go
Original file line number Diff line number Diff line change
@@ -1,103 +1,103 @@
package centralized
package api

import (
"fmt"

"github.com/ChristianMct/helium/api/pb"
"github.com/ChristianMct/helium/circuit"
"github.com/ChristianMct/helium/node"
"github.com/ChristianMct/helium/protocol"
"github.com/ChristianMct/helium/services/compute"
"github.com/ChristianMct/helium/services/setup"
"github.com/ChristianMct/helium/session"
"github.com/ChristianMct/helium/transport/pb"
"github.com/ChristianMct/helium/utils"
)

func getAPIProtocolEvent(event protocol.Event) *pb.ProtocolEvent {
func GetProtocolEvent(event protocol.Event) *pb.ProtocolEvent {
return &pb.ProtocolEvent{
Type: pb.EventType(event.EventType),
Descriptor_: getAPIProtocolDesc(&event.Descriptor),
Descriptor_: GetProtocolDesc(&event.Descriptor),
}
}

func getProtocolEventFromAPI(apiEvent *pb.ProtocolEvent) protocol.Event {
func ToProtocolEvent(apiEvent *pb.ProtocolEvent) protocol.Event {
return protocol.Event{
EventType: protocol.EventType(apiEvent.Type),
Descriptor: *getProtocolDescFromAPI(apiEvent.Descriptor_),
Descriptor: *ToProtocolDesc(apiEvent.Descriptor_),
}
}

func getAPISetupEvent(event setup.Event) *pb.SetupEvent {
func GetSetupEvent(event setup.Event) *pb.SetupEvent {
return &pb.SetupEvent{
ProtocolEvent: getAPIProtocolEvent(event.Event),
ProtocolEvent: GetProtocolEvent(event.Event),
}
}

func getSetupEventFromAPI(apiEvent *pb.SetupEvent) setup.Event {
func ToSetupEvent(apiEvent *pb.SetupEvent) setup.Event {
return setup.Event{
Event: getProtocolEventFromAPI(apiEvent.ProtocolEvent),
Event: ToProtocolEvent(apiEvent.ProtocolEvent),
}
}

func getAPIComputeEvent(event compute.Event) *pb.ComputeEvent {
func GetComputeEvent(event compute.Event) *pb.ComputeEvent {
apiEvent := &pb.ComputeEvent{}
if event.CircuitEvent != nil {
apiEvent.CircuitEvent = &pb.CircuitEvent{
Type: pb.EventType(event.CircuitEvent.EventType),
Descriptor_: getAPICircuitDesc(event.CircuitEvent.Descriptor),
Descriptor_: GetCircuitDesc(event.CircuitEvent.Descriptor),
}
}
if event.ProtocolEvent != nil {
apiEvent.ProtocolEvent = &pb.ProtocolEvent{
Type: pb.EventType(event.ProtocolEvent.EventType),
Descriptor_: getAPIProtocolDesc(&event.ProtocolEvent.Descriptor),
Descriptor_: GetProtocolDesc(&event.ProtocolEvent.Descriptor),
}
}
return apiEvent
}

func getComputeEventFromAPI(apiEvent *pb.ComputeEvent) compute.Event {
func ToComputeEvent(apiEvent *pb.ComputeEvent) compute.Event {
event := compute.Event{}
if apiEvent.CircuitEvent != nil {
event.CircuitEvent = &circuit.Event{
EventType: circuit.EventType(apiEvent.CircuitEvent.Type),
Descriptor: *getCircuitDescFromAPI(apiEvent.CircuitEvent.Descriptor_),
Descriptor: *ToCircuitDesc(apiEvent.CircuitEvent.Descriptor_),
}
}
if apiEvent.ProtocolEvent != nil {
event.ProtocolEvent = &protocol.Event{
EventType: protocol.EventType(apiEvent.ProtocolEvent.Type),
Descriptor: *getProtocolDescFromAPI(apiEvent.ProtocolEvent.Descriptor_),
Descriptor: *ToProtocolDesc(apiEvent.ProtocolEvent.Descriptor_),
}
}
return event
}

func getAPINodeEvent(event node.Event) *pb.NodeEvent {
func GetNodeEvent(event node.Event) *pb.NodeEvent {
apiEvent := &pb.NodeEvent{}
if event.IsSetup() {
apiEvent.Event = &pb.NodeEvent_SetupEvent{SetupEvent: getAPISetupEvent(*event.SetupEvent)}
apiEvent.Event = &pb.NodeEvent_SetupEvent{SetupEvent: GetSetupEvent(*event.SetupEvent)}
}
if event.IsCompute() {
apiEvent.Event = &pb.NodeEvent_ComputeEvent{ComputeEvent: getAPIComputeEvent(*event.ComputeEvent)}
apiEvent.Event = &pb.NodeEvent_ComputeEvent{ComputeEvent: GetComputeEvent(*event.ComputeEvent)}
}
return apiEvent
}

func getNodeEventFromAPI(apiEvent *pb.NodeEvent) node.Event {
func ToNodeEvent(apiEvent *pb.NodeEvent) node.Event {
event := node.Event{}
switch e := apiEvent.Event.(type) {
case *pb.NodeEvent_SetupEvent:
ev := getSetupEventFromAPI(e.SetupEvent)
ev := ToSetupEvent(e.SetupEvent)
event.SetupEvent = &ev
case *pb.NodeEvent_ComputeEvent:
ev := getComputeEventFromAPI(e.ComputeEvent)
ev := ToComputeEvent(e.ComputeEvent)
event.ComputeEvent = &ev
}
return event
}

func getAPIProtocolDesc(pd *protocol.Descriptor) *pb.ProtocolDescriptor {
func GetProtocolDesc(pd *protocol.Descriptor) *pb.ProtocolDescriptor {
apiDesc := &pb.ProtocolDescriptor{
ProtocolType: pb.ProtocolType(pd.Signature.Type),
Args: make(map[string]string, len(pd.Signature.Args)),
Expand All @@ -113,7 +113,7 @@ func getAPIProtocolDesc(pd *protocol.Descriptor) *pb.ProtocolDescriptor {
return apiDesc
}

func getProtocolDescFromAPI(apiPD *pb.ProtocolDescriptor) *protocol.Descriptor {
func ToProtocolDesc(apiPD *pb.ProtocolDescriptor) *protocol.Descriptor {
desc := &protocol.Descriptor{
Signature: protocol.Signature{Type: protocol.Type(apiPD.ProtocolType), Args: make(map[string]string)},
Aggregator: session.NodeID(apiPD.Aggregator.NodeId),
Expand All @@ -128,7 +128,7 @@ func getProtocolDescFromAPI(apiPD *pb.ProtocolDescriptor) *protocol.Descriptor {
return desc
}

func getAPICircuitDesc(cd circuit.Descriptor) *pb.CircuitDescriptor {
func GetCircuitDesc(cd circuit.Descriptor) *pb.CircuitDescriptor {
apiDesc := &pb.CircuitDescriptor{
CircuitSignature: &pb.CircuitSignature{
Name: string(cd.Name),
Expand All @@ -150,7 +150,7 @@ func getAPICircuitDesc(cd circuit.Descriptor) *pb.CircuitDescriptor {
return apiDesc
}

func getCircuitDescFromAPI(apiCd *pb.CircuitDescriptor) *circuit.Descriptor {
func ToCircuitDesc(apiCd *pb.CircuitDescriptor) *circuit.Descriptor {
cd := &circuit.Descriptor{
Signature: circuit.Signature{
Name: circuit.Name(apiCd.CircuitSignature.Name),
Expand All @@ -172,7 +172,7 @@ func getCircuitDescFromAPI(apiCd *pb.CircuitDescriptor) *circuit.Descriptor {
return cd
}

func getAPIShare(s *protocol.Share) (*pb.Share, error) {
func GetShare(s *protocol.Share) (*pb.Share, error) {
outShareBytes, err := s.MarshalBinary()
if err != nil {
return nil, err
Expand All @@ -191,7 +191,7 @@ func getAPIShare(s *protocol.Share) (*pb.Share, error) {
return apiShare, nil
}

func getShareFromAPI(s *pb.Share) (protocol.Share, error) {
func ToShare(s *pb.Share) (protocol.Share, error) {
desc := s.GetMetadata()
pID := protocol.ID(desc.GetProtocolID().GetProtocolID())
pType := protocol.Type(desc.ProtocolType)
Expand All @@ -218,7 +218,7 @@ func getShareFromAPI(s *pb.Share) (protocol.Share, error) {
return ps, nil
}

func getAPICiphertext(ct *session.Ciphertext) (*pb.Ciphertext, error) {
func GetCiphertext(ct *session.Ciphertext) (*pb.Ciphertext, error) {
ctBytes, err := ct.MarshalBinary()
if err != nil {
return nil, err
Expand All @@ -230,7 +230,7 @@ func getAPICiphertext(ct *session.Ciphertext) (*pb.Ciphertext, error) {
}, nil
}

func getCiphertextFromAPI(apiCt *pb.Ciphertext) (*session.Ciphertext, error) {
func ToCiphertext(apiCt *pb.Ciphertext) (*session.Ciphertext, error) {
var ct session.Ciphertext
ct.CiphertextMetadata.ID = session.CiphertextID(apiCt.Metadata.GetId().CiphertextId)
ct.CiphertextMetadata.Type = session.CiphertextType(apiCt.Metadata.GetType())
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
34 changes: 12 additions & 22 deletions transport/centralized/client.go → client.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package centralized
package helium

import (
"context"
Expand All @@ -9,13 +9,14 @@ import (
"strconv"
"time"

"github.com/ChristianMct/helium/api"
"github.com/ChristianMct/helium/api/pb"
"github.com/ChristianMct/helium/circuit"
"github.com/ChristianMct/helium/coordinator"
"github.com/ChristianMct/helium/node"
"github.com/ChristianMct/helium/protocol"
"github.com/ChristianMct/helium/services/compute"
"github.com/ChristianMct/helium/session"
"github.com/ChristianMct/helium/transport/pb"
"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/credentials/insecure"
Expand All @@ -41,20 +42,8 @@ type HeliumClient struct {
statsHandler
}

func RunHeliumClient(ctx context.Context, config node.Config, nl node.List, app node.App, ip compute.InputProvider) (outs <-chan circuit.Output, err error) {

n, err := node.New(config, nl)
if err != nil {
return nil, err
}

hc := NewHeliumClient(n, config.HelperID, nl.AddressOf(config.HelperID))
if err := hc.Connect(); err != nil {
return nil, err
}

return hc.Run(ctx, app, ip)
}
// Dialer is a function that returns a net.Conn to the provided address.
type Dialer = func(c context.Context, addr string) (net.Conn, error)

// NewHeliumClient creates a new helium client.
func NewHeliumClient(node *node.Node, helperID session.NodeID, helperAddress node.Address) *HeliumClient {
Expand Down Expand Up @@ -165,7 +154,8 @@ func (hc *HeliumClient) Register(ctx context.Context) (upstream *coordinator.Cha
}
return
}
eventsStream <- getNodeEventFromAPI(apiEvent)
eventsStream <- api.ToNodeEvent(apiEvent)

}
}()

Expand All @@ -174,7 +164,7 @@ func (hc *HeliumClient) Register(ctx context.Context) (upstream *coordinator.Cha

// PutShare sends a share to the helium server.
func (hc *HeliumClient) PutShare(ctx context.Context, share protocol.Share) error {
apiShare, err := getAPIShare(&share)
apiShare, err := api.GetShare(&share)
if err != nil {
return err
}
Expand All @@ -184,12 +174,12 @@ func (hc *HeliumClient) PutShare(ctx context.Context, share protocol.Share) erro

// GetAggregationOutput queries and returns the aggregation output for a given protocol descriptor.
func (hc *HeliumClient) GetAggregationOutput(ctx context.Context, pd protocol.Descriptor) (*protocol.AggregationOutput, error) {
apiOut, err := hc.HeliumClient.GetAggregationOutput(hc.outgoingContext(ctx), getAPIProtocolDesc(&pd))
apiOut, err := hc.HeliumClient.GetAggregationOutput(hc.outgoingContext(ctx), api.GetProtocolDesc(&pd))
if err != nil {
return nil, err
}

s, err := getShareFromAPI(apiOut.AggregatedShare)
s, err := api.ToShare(apiOut.AggregatedShare)
if err != nil {
return nil, err
}
Expand All @@ -202,12 +192,12 @@ func (hc *HeliumClient) GetCiphertext(ctx context.Context, ctID session.Cipherte
if err != nil {
return nil, err
}
return getCiphertextFromAPI(apiCt)
return api.ToCiphertext(apiCt)
}

// PutCiphertext sends a ciphertext to the helium server.
func (hc *HeliumClient) PutCiphertext(ctx context.Context, ct session.Ciphertext) error {
apiCt, err := getAPICiphertext(&ct)
apiCt, err := api.GetCiphertext(&ct)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion transport/centralized/context.go → context.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package centralized
package helium

import (
"context"
Expand Down
6 changes: 3 additions & 3 deletions examples/vec-mul/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ import (
"fmt"
"log"

"github.com/ChristianMct/helium"
"github.com/ChristianMct/helium/circuit"
"github.com/ChristianMct/helium/node"
"github.com/ChristianMct/helium/objectstore"
"github.com/ChristianMct/helium/protocol"
"github.com/ChristianMct/helium/services/compute"
"github.com/ChristianMct/helium/services/setup"
"github.com/ChristianMct/helium/session"
"github.com/ChristianMct/helium/transport/centralized"
"github.com/tuneinsight/lattigo/v5/core/rlwe"
"github.com/tuneinsight/lattigo/v5/he"
"github.com/tuneinsight/lattigo/v5/mhe"
Expand Down Expand Up @@ -180,9 +180,9 @@ func main() {

// runs the app on a new node
if nodeID == helperID {
cdescs, outs, err = centralized.RunHeliumServer(ctx, config, nodelist, app, ip)
cdescs, outs, err = helium.RunHeliumServer(ctx, config, nodelist, app, ip)
} else {
outs, err = centralized.RunHeliumClient(ctx, config, nodelist, app, ip)
outs, err = helium.RunHeliumClient(ctx, config, nodelist, app, ip)
}
if err != nil {
log.Fatalf("could not run node: %s", err)
Expand Down
44 changes: 43 additions & 1 deletion helium.go
Original file line number Diff line number Diff line change
@@ -1,2 +1,44 @@
// Package helium provides the main types and interfaces for the Helium framework.
package helium

import (
"context"
"net"

"github.com/ChristianMct/helium/circuit"
"github.com/ChristianMct/helium/node"
"github.com/ChristianMct/helium/services/compute"
)

func RunHeliumServer(ctx context.Context, config node.Config, nl node.List, app node.App, ip compute.InputProvider) (cdescs chan<- circuit.Descriptor, outs <-chan circuit.Output, err error) {

helperNode, err := node.New(config, nl)
if err != nil {
return nil, nil, err
}

hsv := NewHeliumServer(helperNode)

lis, err := net.Listen("tcp", string(config.Address))
if err != nil {
return nil, nil, err
}

go hsv.Serve(lis)

return hsv.Run(ctx, app, ip)
}

func RunHeliumClient(ctx context.Context, config node.Config, nl node.List, app node.App, ip compute.InputProvider) (outs <-chan circuit.Output, err error) {

n, err := node.New(config, nl)
if err != nil {
return nil, err
}

hc := NewHeliumClient(n, config.HelperID, nl.AddressOf(config.HelperID))
if err := hc.Connect(); err != nil {
return nil, err
}

return hc.Run(ctx, app, ip)
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package centralized
package helium

import (
"context"
Expand Down
Loading

0 comments on commit 47f2b6e

Please sign in to comment.