forked from perlin-network/noise
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mod.go
94 lines (75 loc) · 2.08 KB
/
mod.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
package protocol
import (
"github.com/Yayg/noise"
"github.com/Yayg/noise/log"
"github.com/pkg/errors"
"sync"
"sync/atomic"
)
// Keys under which this package keeps its bookkeeping state on noise nodes
// and peers.
const (
	// KeyProtocolCurrentBlockIndex is the peer-scoped key holding the int
	// index of the Block the peer is currently at. It is initialized to 0 on
	// first access and advanced after each successful Block.OnBegin.
	KeyProtocolCurrentBlockIndex = "protocol.current_block_index"

	// KeyProtocolEnforceOnce is the node-scoped key holding the *sync.Once
	// that guarantees Enforce installs its hooks on a node at most once.
	KeyProtocolEnforceOnce = "protocol.enforce_once"
)
// DisconnectPeer is the sentinel error a Block's OnBegin may return (directly
// or wrapped via github.com/pkg/errors, since it is matched with errors.Cause)
// to ask that the peer be disconnected instead of merely logging the failure.
var DisconnectPeer = errors.New("peer disconnect requested")
// Block is a single stage of a Protocol. A Protocol drives every peer
// through its blocks one at a time, in the order they were registered.
type Block interface {
	// OnRegister is called once per node, when Enforce first takes effect
	// on that node, for each registered block in registration order.
	OnRegister(p *Protocol, node *noise.Node)

	// OnBegin runs this stage for one peer. A nil return moves the peer on
	// to the next block. An error whose cause is DisconnectPeer makes the
	// peer disconnect; any other error is logged. After any error the peer
	// stops advancing through the protocol.
	OnBegin(p *Protocol, peer *noise.Peer) error

	// OnEnd is called when a peer disconnects while this block is the
	// peer's current block. Blocks the peer has already completed are not
	// notified, and nothing is called if the peer finished every block.
	OnEnd(p *Protocol, peer *noise.Peer) error
}
// Protocol is an ordered sequence of Blocks that every peer of a node is
// driven through once Enforce has been called. Build one with New and
// Register.
type Protocol struct {
	// blocksSealed is atomically set to 1 by Enforce; Register panics once
	// it is set.
	blocksSealed uint32

	// blocks holds the registered stages in registration order.
	blocks []Block
}
// New returns an empty, unsealed Protocol. Add stages with Register, then
// attach it to a node with Enforce.
func New() *Protocol {
	return new(Protocol)
}
// Register appends blk to the end of the protocol's block sequence and
// returns p, so registrations can be chained:
//
//	New().Register(a).Register(b)
//
// Register panics if Enforce has already been called on p. That check is a
// best-effort guard that surfaces misuse early; it does not synchronize
// concurrent calls to Register and Enforce.
func (p *Protocol) Register(blk Block) *Protocol {
	sealed := atomic.LoadUint32(&p.blocksSealed) == 1
	if !sealed {
		p.blocks = append(p.blocks, blk)
		return p
	}
	panic("Register() cannot be called after Enforce().")
}
// Enforce enforces that all peers of a node follow the given protocol.
//
// Calling Enforce seals p: any later call to Register panics. The first
// Enforce on a given node (tracked by a *sync.Once stored on the node under
// KeyProtocolEnforceOnce) calls OnRegister on every block in registration
// order and installs a peer-init hook. Later calls on the same node, even
// with a different *Protocol, only seal their receiver and do nothing else.
//
// For every peer initialized on the node, the hook:
//   - registers a disconnect callback that calls OnEnd on the block the peer
//     is currently at (nothing is called if the peer completed every block);
//   - starts a goroutine that calls OnBegin on each block in turn, starting
//     at the index stored on the peer under KeyProtocolCurrentBlockIndex
//     (0 by default) and advancing it after each successful block.
//
// If OnBegin fails with an error whose cause is DisconnectPeer, the peer is
// disconnected; any other error is logged. In both cases the goroutine stops
// and the peer remains at the failing block.
func (p *Protocol) Enforce(node *noise.Node) {
	atomic.StoreUint32(&p.blocksSealed, 1)

	once := node.LoadOrStore(KeyProtocolEnforceOnce, new(sync.Once)).(*sync.Once)
	once.Do(func() {
		for _, block := range p.blocks {
			block.OnRegister(p, node)
		}

		node.OnPeerInit(func(node *noise.Node, peer *noise.Peer) error {
			// Install the disconnect hook before spawning the protocol
			// goroutine. If it were registered from inside the goroutine, a
			// peer disconnecting before that goroutine got scheduled could
			// miss OnEnd for its current block.
			peer.OnDisconnect(func(node *noise.Node, peer *noise.Peer) error {
				blockIndex := peer.LoadOrStore(KeyProtocolCurrentBlockIndex, 0).(int)
				if blockIndex >= len(p.blocks) {
					return nil
				}
				return p.blocks[blockIndex].OnEnd(p, peer)
			})

			go func() {
				for {
					blockIndex := peer.LoadOrStore(KeyProtocolCurrentBlockIndex, 0).(int)
					if blockIndex >= len(p.blocks) {
						return
					}

					if err := p.blocks[blockIndex].OnBegin(p, peer); err != nil {
						switch errors.Cause(err) {
						case DisconnectPeer:
							peer.Disconnect()
						default:
							log.Warn().Err(err).Msg("Received an error following protocol.")
						}
						return
					}

					peer.Set(KeyProtocolCurrentBlockIndex, blockIndex+1)
				}
			}()

			return nil
		})
	})
}