-
Notifications
You must be signed in to change notification settings - Fork 0
/
transport.go
139 lines (117 loc) · 3.69 KB
/
transport.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package node
import (
"fmt"
"github.com/ChristianMct/helium/protocols"
"github.com/ChristianMct/helium/services/compute"
"github.com/ChristianMct/helium/services/setup"
"github.com/ChristianMct/helium/sessions"
"golang.org/x/net/context"
)
// Transport is the node-level transport. It is the union of the transport
// interfaces required by the setup service (setup.Transport) and by the
// compute service (compute.Transport). newServicesTransport splits one
// Transport into separate per-service transports.
type Transport interface {
setup.Transport
compute.Transport
}
// setupTransport is the per-service transport for setup protocols:
// newServicesTransport routes incoming shares whose ProtocolType reports
// IsSetup() to it. It adds nothing beyond the embedded protocolTransport.
type setupTransport struct {
protocolTransport
}
// nodeTransport bundles the per-service transports built by
// newServicesTransport.
//
// Both embedded structs embed a protocolTransport, so its fields and methods
// (IncomingShares, OutgoingShares, GetAggregationOutput) are promoted at the
// same depth through each of them and are therefore ambiguous selectors on
// nodeTransport itself. Access them explicitly via nt.setupTransport or
// nt.computeTransport. PutCiphertext/GetCiphertext are unambiguous and come
// from computeTransport.
type nodeTransport struct {
setupTransport
computeTransport
}
// newServicesTransport splits the node-level Transport t into one transport
// for the setup service and one for the compute service.
//
// Both service transports send on the single channel returned by
// t.OutgoingShares() and resolve aggregation outputs through
// t.GetAggregationOutput. The compute transport additionally forwards
// ciphertext storage and retrieval to t.PutCiphertext and t.GetCiphertext.
//
// Incoming shares are demultiplexed by a background goroutine:
//   - shares whose ProtocolType reports IsSetup() go to the setup transport;
//   - shares whose ProtocolType reports IsCompute() go to the compute transport.
//
// The per-service incoming channels are unbuffered. A service that stops
// draining its incoming shares therefore also stalls delivery to the other
// service. The goroutine exits once t.IncomingShares() is closed; it then
// closes both per-service incoming channels, being their only sender.
//
// The goroutine panics if it receives a share whose protocol type is neither
// a setup nor a compute type.
func newServicesTransport(t Transport) *nodeTransport {
	outShares := t.OutgoingShares()

	// newProtocolTransport builds a protocolTransport that shares the node's
	// outgoing channel and aggregation-output lookup but has its own incoming
	// channel, so each service only sees the shares routed to it.
	newProtocolTransport := func() protocolTransport {
		return protocolTransport{
			outshares:            outShares,
			inshares:             make(chan protocols.Share),
			getAggregationOutput: t.GetAggregationOutput,
		}
	}

	nt := &nodeTransport{
		setupTransport: setupTransport{
			protocolTransport: newProtocolTransport(),
		},
		computeTransport: computeTransport{
			protocolTransport: newProtocolTransport(),
			putCiphertext:     t.PutCiphertext,
			getCiphertext:     t.GetCiphertext,
		},
	}

	go func() {
		for s := range t.IncomingShares() {
			switch {
			case s.ProtocolType.IsSetup():
				nt.setupTransport.inshares <- s
			case s.ProtocolType.IsCompute():
				nt.computeTransport.inshares <- s
			default:
				// Include the offending type so the crash is diagnosable.
				panic(fmt.Errorf("unknown protocol type: %v", s.ProtocolType))
			}
		}
		// The upstream channel is closed: propagate end-of-stream to both
		// services. This goroutine is the only sender on these channels.
		close(nt.setupTransport.inshares)
		close(nt.computeTransport.inshares)
	}()

	return nt
}
// protocolTransport holds the share-exchange state shared by the per-service
// transports: a channel to send shares on, a channel to receive shares from,
// and a function used to fetch aggregation outputs.
type protocolTransport struct {
outshares chan<- protocols.Share // send-only; exposed by OutgoingShares
inshares chan protocols.Share // exposed receive-only by IncomingShares; in this file it is fed and closed by newServicesTransport's routing goroutine
getAggregationOutput func(context.Context, protocols.Descriptor) (*protocols.AggregationOutput, error) // backs GetAggregationOutput
}
// IncomingShares returns a receive-only view of the channel on which shares
// destined for this transport are delivered.
func (p *protocolTransport) IncomingShares() <-chan protocols.Share {
	var in <-chan protocols.Share = p.inshares
	return in
}

// OutgoingShares returns the send-only channel on which shares produced by
// this transport's user are to be sent.
func (p *protocolTransport) OutgoingShares() chan<- protocols.Share {
	out := p.outshares
	return out
}

// GetAggregationOutput fetches the aggregation output of the protocol
// described by pd by calling the configured getAggregationOutput function.
func (p *protocolTransport) GetAggregationOutput(ctx context.Context, pd protocols.Descriptor) (*protocols.AggregationOutput, error) {
	fetch := p.getAggregationOutput
	return fetch(ctx, pd)
}
// computeTransport is the per-service transport for compute protocols:
// newServicesTransport routes incoming shares whose ProtocolType reports
// IsCompute() to it. On top of the embedded protocolTransport it forwards
// ciphertext storage and retrieval to the functions below.
type computeTransport struct {
protocolTransport
putCiphertext func(ctx context.Context, ct sessions.Ciphertext) error // backs PutCiphertext
getCiphertext func(ctx context.Context, ctID sessions.CiphertextID) (*sessions.Ciphertext, error) // backs GetCiphertext
}
// PutCiphertext hands ct to the configured putCiphertext function and returns
// its error unchanged.
func (c *computeTransport) PutCiphertext(ctx context.Context, ct sessions.Ciphertext) error {
	put := c.putCiphertext
	return put(ctx, ct)
}

// GetCiphertext looks up the ciphertext identified by ctID through the
// configured getCiphertext function and returns its result unchanged.
func (c *computeTransport) GetCiphertext(ctx context.Context, ctID sessions.CiphertextID) (*sessions.Ciphertext, error) {
	get := c.getCiphertext
	return get(ctx, ctID)
}
// testTransport is a Transport for tests. Share exchange comes from the
// embedded *protocols.TestTransport. Aggregation outputs are read from
// helperSetupSrv, and ciphertexts are stored in and fetched from
// helperCompSrv (see the methods declared on testTransport below).
type testTransport struct {
hid sessions.NodeID // node ID for which TransportFor returns this transport itself (presumably the helper's ID — confirm)
helperSetupSrv *setup.Service // setup service queried by GetAggregationOutput
helperCompSrv *compute.Service // compute service used by PutCiphertext/GetCiphertext
*protocols.TestTransport
}
// NewTestTransport returns a testTransport for the node hid, backed by a fresh
// protocols.TestTransport. Aggregation outputs are served by helperSetupSrv and
// ciphertexts by helperCompSrv.
func NewTestTransport(hid sessions.NodeID, helperSetupSrv *setup.Service, helperCompSrv *compute.Service) *testTransport {
	return &testTransport{
		hid:            hid,
		helperSetupSrv: helperSetupSrv,
		helperCompSrv:  helperCompSrv,
		TestTransport:  protocols.NewTestTransport(),
	}
}
// TransportFor returns the Transport that node nid should use.
//
// For nid == tt.hid it returns tt itself (as a value). For any other node it
// returns a new *testTransport wrapping tt.TestTransport.TransportFor(nid)
// that shares tt's helper setup and compute services.
//
// NOTE(review): the derived transport's hid is left at its zero value, so
// calling TransportFor on it never takes the early-return branch — confirm
// this is intended.
func (tt testTransport) TransportFor(nid sessions.NodeID) Transport {
	if nid != tt.hid {
		peer := tt.TestTransport.TransportFor(nid)
		return &testTransport{
			helperSetupSrv: tt.helperSetupSrv,
			helperCompSrv:  tt.helperCompSrv,
			TestTransport:  peer,
		}
	}
	return tt
}
// GetAggregationOutput returns the aggregation output of the protocol
// described by pd as reported by the helper's setup service. Being declared
// directly on testTransport, it takes precedence over any method of the same
// name promoted from the embedded *protocols.TestTransport.
func (tt testTransport) GetAggregationOutput(ctx context.Context, pd protocols.Descriptor) (*protocols.AggregationOutput, error) {
	setupSrv := tt.helperSetupSrv
	return setupSrv.GetAggregationOutput(ctx, pd)
}

// PutCiphertext stores ct in the helper's compute service.
func (tt testTransport) PutCiphertext(ctx context.Context, ct sessions.Ciphertext) error {
	compSrv := tt.helperCompSrv
	return compSrv.PutCiphertext(ctx, ct)
}

// GetCiphertext fetches the ciphertext identified by ctID from the helper's
// compute service.
func (tt testTransport) GetCiphertext(ctx context.Context, ctID sessions.CiphertextID) (*sessions.Ciphertext, error) {
	compSrv := tt.helperCompSrv
	return compSrv.GetCiphertext(ctx, ctID)
}