/
group_state.go
117 lines (99 loc) · 3.1 KB
/
group_state.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
package session
import (
"github.com/RTann/libsignal-go/protocol/curve"
v1 "github.com/RTann/libsignal-go/protocol/generated/v1"
"github.com/RTann/libsignal-go/protocol/senderkey"
)
type (
	// GroupState is the state of a group (sender-key) session chain, backed
	// by a generated v1.SenderKeyStateStructure message.
	GroupState struct {
		state *v1.SenderKeyStateStructure
	}

	// GroupStateConfig carries the inputs accepted by NewGroupState.
	GroupStateConfig struct {
		// MessageVersion is stored widened to uint32.
		MessageVersion uint8

		// ChainID is stored as the state's chain ID.
		ChainID uint32

		// Iteration is the iteration of the initial sender chain key.
		Iteration uint32

		// ChainKey is the seed of the initial sender chain key. NewGroupState
		// stores a copy of it.
		ChainKey []byte

		// SignatureKey is the public signing key. NewGroupState calls its Bytes
		// method unconditionally, so it must be set.
		SignatureKey curve.PublicKey

		// SignaturePrivateKey is the private signing key. It may be nil, in
		// which case no private key bytes are stored.
		SignaturePrivateKey curve.PrivateKey
	}
)
// NewGroupState creates a GroupState from cfg.
//
// cfg.ChainKey is copied, so the returned state does not share memory with the
// caller's slice. cfg.SignaturePrivateKey may be nil, in which case no private
// signing key bytes are stored. cfg.SignatureKey has its Bytes method called
// unconditionally, so it must be set. cfg.MessageVersion is widened to uint32.
func NewGroupState(cfg GroupStateConfig) *GroupState {
	signing := &v1.SenderKeyStateStructure_SenderSigningKey{}
	if priv := cfg.SignaturePrivateKey; priv != nil {
		signing.Private = priv.Bytes()
	}
	signing.Public = cfg.SignatureKey.Bytes()

	chain := &v1.SenderKeyStateStructure_SenderChainKey{
		Iteration: cfg.Iteration,
		// Defensive copy: the stored seed must not alias cfg.ChainKey.
		Seed: append([]byte{}, cfg.ChainKey...),
	}

	return &GroupState{
		state: &v1.SenderKeyStateStructure{
			MessageVersion:   uint32(cfg.MessageVersion),
			ChainId:          cfg.ChainID,
			SenderChainKey:   chain,
			SenderSigningKey: signing,
		},
	}
}
// Version returns the message version recorded in the state. A stored value of
// 0 is reported as 3 (NOTE(review): presumably the default for states that
// were serialized without an explicit version — confirm).
func (s *GroupState) Version() uint32 {
	if v := s.state.GetMessageVersion(); v != 0 {
		return v
	}
	return 3
}
// ChainID returns the chain ID stored in the state.
func (s *GroupState) ChainID() uint32 {
	id := s.state.GetChainId()
	return id
}
// SenderChainKey rebuilds the current sender chain key from the seed and
// iteration stored in the state.
func (s *GroupState) SenderChainKey() senderkey.ChainKey {
	stored := s.state.GetSenderChainKey()
	seed, iteration := stored.GetSeed(), stored.GetIteration()
	return senderkey.NewChainKey(seed, iteration)
}
// SetSenderChainKey replaces the stored sender chain key with chainKey's
// iteration and seed.
//
// The seed bytes are copied so the stored state never aliases the slice
// returned by chainKey.Seed. This is consistent with NewGroupState, which also
// copies the chain key it is given.
func (s *GroupState) SetSenderChainKey(chainKey senderkey.ChainKey) {
	s.state.SenderChainKey = &v1.SenderKeyStateStructure_SenderChainKey{
		Iteration: chainKey.Iteration(),
		Seed:      append([]byte(nil), chainKey.Seed()...),
	}
}
// PrivateSigningKey decodes the stored private signing key bytes with
// curve.NewPrivateKey and returns its result unchanged. When the state holds
// no private key bytes (e.g. NewGroupState was given a nil private key), the
// outcome is whatever curve.NewPrivateKey does with empty input.
func (s *GroupState) PrivateSigningKey() (curve.PrivateKey, error) {
	raw := s.state.GetSenderSigningKey().GetPrivate()
	return curve.NewPrivateKey(raw)
}
// PublicSigningKey decodes the stored public signing key bytes with
// curve.NewPublicKey and returns its result unchanged.
func (s *GroupState) PublicSigningKey() (curve.PublicKey, error) {
	raw := s.state.GetSenderSigningKey().GetPublic()
	return curve.NewPublicKey(raw)
}
// AddMessageKey appends key's iteration and seed to the stored message keys.
//
// If that leaves more than maxMessageKeys entries, the oldest ones (at the
// front of the list) are dropped until exactly maxMessageKeys remain. Trimming
// all of the excess, rather than a single entry, also brings a state that was
// loaded already over the limit back within bounds.
func (s *GroupState) AddMessageKey(key senderkey.MessageKey) {
	keys := append(s.state.GetSenderMessageKeys(), &v1.SenderKeyStateStructure_SenderMessageKey{
		Iteration: key.Iteration(),
		Seed:      key.Seed(),
	})
	if excess := len(keys) - maxMessageKeys; excess > 0 {
		// Clear dropped entries so the shared backing array does not keep
		// them reachable.
		for i := range keys[:excess] {
			keys[i] = nil
		}
		keys = keys[excess:]
	}
	s.state.SenderMessageKeys = keys
}
// RemoveMessageKeys looks up the stored message key for iteration. If it finds
// one, it derives a senderkey.MessageKey from it with senderkey.DeriveMessageKey
// and removes the entry from the state.
//
// It returns (key, true, nil) on success and (zero value, false, nil) when no
// key for iteration is stored. If derivation fails, the stored entry is left in
// place and (zero value, false, err) is returned.
func (s *GroupState) RemoveMessageKeys(iteration uint32) (senderkey.MessageKey, bool, error) {
	keys := s.state.GetSenderMessageKeys()

	idx := -1
	for i, key := range keys {
		if key.GetIteration() == iteration {
			idx = i
			break
		}
	}
	if idx < 0 {
		return senderkey.MessageKey{}, false, nil
	}

	found := keys[idx]
	derived, err := senderkey.DeriveMessageKey(found.GetSeed(), found.GetIteration())
	if err != nil {
		return senderkey.MessageKey{}, false, err
	}

	// Shift the tail left over the removed entry, then clear the vacated final
	// slot. Without this, the backing array keeps a stale duplicate pointer to
	// the last element.
	last := len(keys) - 1
	copy(keys[idx:], keys[idx+1:])
	keys[last] = nil
	s.state.SenderMessageKeys = keys[:last]

	return derived, true, nil
}