Skip to content

Commit

Permalink
better retry mechanism and agOutRec
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristianMct committed Apr 22, 2024
1 parent 1adb679 commit c24f65b
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 43 deletions.
55 changes: 29 additions & 26 deletions protocols/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ func (s *Executor) Run(ctx context.Context, trans Transport) error { // TODO: ca
return nil
}

func (s *Executor) runAsAggregator(ctx context.Context, sess *sessions.Session, pd Descriptor, aggOutRec AggregationOutputReceiver) (err error) {
func (s *Executor) runAsAggregator(ctx context.Context, sess *sessions.Session, pd Descriptor) (aggOut AggregationOutput) {

if !s.isAggregatorFor(pd) {
panic(fmt.Errorf("not the aggregator for protocol"))
Expand Down Expand Up @@ -379,7 +379,8 @@ func (s *Executor) runAsAggregator(ctx context.Context, sess *sessions.Session,
if err != nil {
cancelAgg()
clearProtocol()
return fmt.Errorf("cannot get input for protocol: %w", err)
aggOut.Error = fmt.Errorf("cannot get input for protocol: %w", err)
return
}

s.upstream.Outgoing <- Event{EventType: Started, Descriptor: pd}
Expand All @@ -402,11 +403,10 @@ func (s *Executor) runAsAggregator(ctx context.Context, sess *sessions.Session,
}

//go func() {
var agg AggregationOutput
agg.Descriptor = pd
aggOut.Descriptor = pd
for done := false; !done; {
select {
case agg = <-aggregation:
case aggOut = <-aggregation:
done = true
case participantID := <-disconnected:

Expand All @@ -417,32 +417,19 @@ func (s *Executor) runAsAggregator(ctx context.Context, sess *sessions.Session,

s.Logf("node %s disconnected before providing its share, aborting protocol %s", participantID, pd.HID())
done = true
agg.Error = fmt.Errorf("node %s disconnected before providing its share", participantID)
aggOut.Error = fmt.Errorf("node %s disconnected before providing its share", participantID)
}
}
cancelAgg()
clearProtocol()

if agg.Error != nil {
if aggOut.Error != nil {
s.upstream.Outgoing <- Event{EventType: Failed, Descriptor: pd}
return
} else {
s.upstream.Outgoing <- Event{EventType: Completed, Descriptor: pd}
}

err = aggOutRec(ctx, agg)
if err != nil {
return fmt.Errorf("error calling aggregation output receiver: %w", err)
}

if agg.Error != nil {
// re-run the failing sig
sig := pd.Signature
if sig.Type == RKG1 {
sig.Type = RKG
}
return s.runSignature(ctx, sig, aggOutRec)
}

s.runningProtoMu.Lock()
s.completedProtos = append(s.completedProtos, pd)
s.runningProtoMu.Unlock()
Expand Down Expand Up @@ -472,22 +459,38 @@ func (s *Executor) runSignature(ctx context.Context, sig Signature, aggOutRec Ag
}

//s.Logf("getting key operation descriptor: %s", sig)
var aggOut AggregationOutput
for {
// gets a protocol descriptor with available nodes
pd := s.getProtocolDescriptor(sig, sess)

pd := s.getProtocolDescriptor(sig, sess)
// attempts to run the protocol
if aggOut = s.runAsAggregator(ctx, sess, pd); aggOut.Error == nil {
break
}

s.Logf("[%s] error during aggregation: %s, retrying...", pd.HID(), aggOut.Error)
}

//s.Logf("running key operation descriptor: %s", pd)

return s.runAsAggregator(ctx, sess, pd, aggOutRec)
err = aggOutRec(ctx, aggOut)
if err != nil {
return fmt.Errorf("error calling aggregation output receiver: %w", err)
}

return
}

func (s *Executor) RunDescriptorAsAggregator(ctx context.Context, pd Descriptor, aggOutRec AggregationOutputReceiver) (err error) {
func (s *Executor) RunDescriptorAsAggregator(ctx context.Context, pd Descriptor) (aggOut *AggregationOutput, err error) {

sess, has := s.sessProvider.GetSessionFromContext(ctx)
if !has {
return fmt.Errorf("could not extract session from context")
return nil, fmt.Errorf("could not extract session from context")
}

return s.runAsAggregator(ctx, sess, pd, aggOutRec)
agg := s.runAsAggregator(ctx, sess, pd)
return &agg, agg.Error
}

func (s *Executor) runAsParticipant(ctx context.Context, pd Descriptor) error {
Expand Down
8 changes: 1 addition & 7 deletions protocols/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,10 @@ func TestExecutor(t *testing.T) {
return rkg1AggOut.Share.MHEShare, nil
}

aggOutC := make(chan AggregationOutput, 1) // TODO the next command is blocking, see if we want to make it async
err = helper.RunDescriptorAsAggregator(ctx, r1Pd, func(ctx context.Context, ao AggregationOutput) error {
aggOutC <- ao
return nil
})
rkg1AggOut, err = helper.RunDescriptorAsAggregator(ctx, r1Pd)
if err != nil {
return nil, err
}
rkg1AggOutv := <-aggOutC
rkg1AggOut = &rkg1AggOutv
if rkg1AggOut.Error != nil {
return nil, rkg1AggOut.Error
}
Expand Down
14 changes: 4 additions & 10 deletions services/setup/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,20 +345,14 @@ func (s *Service) GetAggregationOutput(ctx context.Context, pd protocols.Descrip
}
} else {
// TODO: prevent double run of protocol
aggOutC := make(chan protocols.AggregationOutput, 1)
err := s.executor.RunDescriptorAsAggregator(ctx, pd, func(ctx context.Context, ao protocols.AggregationOutput) error {
aggOutC <- ao
return nil
})
out, err = s.executor.RunDescriptorAsAggregator(ctx, pd)
if err != nil {
return nil, err
}
aggOut := <-aggOutC
if aggOut.Error != nil {
return nil, fmt.Errorf("aggregation error for %s: %w", pd.HID(), aggOut.Error)
if out.Error != nil {
return nil, fmt.Errorf("aggregation error for %s: %w", pd.HID(), out.Error)
}
out = &aggOut
err = s.resBackend.Put(ctx, pd, aggOut.Share)
err = s.resBackend.Put(ctx, pd, out.Share)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit c24f65b

Please sign in to comment.