Skip to content

Commit

Permalink
handling of failing RKG
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristianMct committed Apr 22, 2024
1 parent d72e844 commit 11533f7
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 90 deletions.
120 changes: 30 additions & 90 deletions protocols/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ func (s *Executor) Run(ctx context.Context, trans Transport) error { // TODO: ca
func (s *Executor) runAsAggregator(ctx context.Context, sess *sessions.Session, pd Descriptor, aggOutRec AggregationOutputReceiver) (err error) {

if !s.isAggregatorFor(pd) {
return fmt.Errorf("not the aggregator for protocol")
panic(fmt.Errorf("not the aggregator for protocol"))
}

proto, err := NewProtocol(pd, sess)
Expand All @@ -332,6 +332,7 @@ func (s *Executor) runAsAggregator(ctx context.Context, sess *sessions.Session,
var aggregation <-chan AggregationOutput
var disconnected chan sessions.NodeID
s.runningProtoMu.Lock()
s.connectedNodesMu.RLock()
incoming := make(chan Share)
disconnected = make(chan sessions.NodeID, len(pd.Participants))
s.runningProtos[pid] = struct {
Expand All @@ -343,7 +344,28 @@ func (s *Executor) runAsAggregator(ctx context.Context, sess *sessions.Session,
incoming: incoming,
disconnected: disconnected,
}

// nodes might have disconnected before the protocol started
for _, nid := range pd.Participants {
if _, has := s.connectedNodes[nid]; !has {
disconnected <- nid
}
}
s.connectedNodesMu.RUnlock()
s.runningProtoMu.Unlock()

clearProtocol := func() {
s.connectedNodesMu.Lock()
for _, part := range pd.Participants {
s.connectedNodes[part].Remove(pd.ID())
}
s.connectedNodesMu.Unlock()
s.connectedNodesCond.Broadcast()

s.runningProtoMu.Lock()
delete(s.runningProtos, pid)
s.runningProtoMu.Unlock()
}
//s.runningProtoWg.Add(1)

// runs the aggregation
Expand All @@ -355,6 +377,7 @@ func (s *Executor) runAsAggregator(ctx context.Context, sess *sessions.Session,
input, err := s.inputProvider(ctx, pd)
if err != nil {
cancelAgg()
clearProtocol()
return fmt.Errorf("cannot get input for protocol: %w", err)
}

Expand All @@ -364,16 +387,14 @@ func (s *Executor) runAsAggregator(ctx context.Context, sess *sessions.Session,

sk, err := sess.GetSecretKeyForGroup(pd.Participants) // TODO: cache
if err != nil {
cancelAgg()
return err
panic(err)
}

// runs the share generation and sending to aggregator
share := proto.AllocateShare()
err = proto.GenShare(sk, input, &share)
if err != nil {
cancelAgg()
return err
panic(err)
}
s.transport.OutgoingShares() <- share
s.Logf("completed participation for %s", pd.HID())
Expand All @@ -399,23 +420,18 @@ func (s *Executor) runAsAggregator(ctx context.Context, sess *sessions.Session,
}
}
cancelAgg()
clearProtocol()

if agg.Error != nil {
s.upstream.Outgoing <- Event{EventType: Failed, Descriptor: pd}
} else {
s.upstream.Outgoing <- Event{EventType: Completed, Descriptor: pd}
}

s.connectedNodesMu.Lock()
for _, part := range pd.Participants {
s.connectedNodes[part].Remove(pd.ID())
err = aggOutRec(ctx, agg)
if err != nil {
return fmt.Errorf("error calling aggregation output receiver: %w", err)
}
s.connectedNodesMu.Unlock()
s.connectedNodesCond.Broadcast()

s.runningProtoMu.Lock()
delete(s.runningProtos, pid)
s.runningProtoMu.Unlock()

if agg.Error != nil {
// re-run the failing sig
Expand All @@ -426,10 +442,6 @@ func (s *Executor) runAsAggregator(ctx context.Context, sess *sessions.Session,
return s.runSignature(ctx, sig, aggOutRec)
}

err = aggOutRec(ctx, agg)
if err != nil {
return fmt.Errorf("error calling aggregation output receiver: %w", err)
}
s.runningProtoMu.Lock()
s.completedProtos = append(s.completedProtos, pd)
s.runningProtoMu.Unlock()
Expand Down Expand Up @@ -658,78 +670,6 @@ func (s *Executor) isKeySwitchReceiver(pd Descriptor) bool {
return false
}

// type testCoordinator struct {
// hid session.NodeID
// log []Event
// closed bool
// incoming, outgoing chan Event
// clients []chan Event

// l sync.Mutex
// }

// func NewTestCoordinator(hid session.NodeID) *testCoordinator {
// tc := &testCoordinator{hid: hid,
// log: make([]Event, 0),
// incoming: make(chan Event),
// outgoing: make(chan Event),
// clients: make([]chan Event, 0)}
// go func() {
// for ev := range tc.outgoing {
// tc.l.Lock()
// tc.log = append(tc.log, ev)
// for _, cli := range tc.clients {
// cli <- ev
// }
// tc.l.Unlock()
// }
// tc.l.Lock()
// tc.closed = true
// for _, cli := range tc.clients {
// close(cli)
// }
// tc.l.Unlock()
// }()
// return tc
// }

// func (tc *testCoordinator) Close() {
// close(tc.incoming)
// }

// func (tc *testCoordinator) Register(ctx context.Context) (evChan *EventChannel, present int, err error) {

// nid, has := session.NodeIDFromContext(ctx)
// if !has {
// return nil, 0, fmt.Errorf("no node id found in context")
// }

// if nid == tc.hid {
// return &EventChannel{Incoming: tc.incoming, Outgoing: tc.outgoing}, len(tc.log), nil
// }

// tc.l.Lock()
// p := len(tc.log)
// cliInc, cliOut := make(chan Event, p), make(chan Event)
// for _, ev := range tc.log {
// cliInc <- ev
// }
// if tc.closed {
// close(cliInc)
// } else {
// tc.clients = append(tc.clients, cliInc)
// }
// tc.l.Unlock()

// go func() {
// for ev := range cliOut {
// tc.outgoing <- ev
// }
// }()

// return &EventChannel{Incoming: cliInc, Outgoing: cliOut}, p, nil
// }

type TestTransport struct {
incoming, outgoing chan Share
}
Expand Down
3 changes: 3 additions & 0 deletions services/setup/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@ func (s *Service) GetAggregationOutput(ctx context.Context, pd protocols.Descrip
return nil, err
}
aggOut := <-aggOutC
if aggOut.Error != nil {
return nil, fmt.Errorf("aggregation error for %s: %w", pd.HID(), aggOut.Error)
}
out = &aggOut
err = s.resBackend.Put(ctx, pd, aggOut.Share)
if err != nil {
Expand Down

0 comments on commit 11533f7

Please sign in to comment.