From 338395668fbb8a7819c0fccf54dccaa4d7f0ae9e Mon Sep 17 00:00:00 2001 From: Adam Langley Date: Sat, 9 Nov 2013 13:10:17 -0500 Subject: [PATCH] Wire in new ratchet system. --- client/client.go | 58 +++++++--- client/client_test.go | 119 +++++++++++++++----- client/disk.go | 14 ++- client/disk/client.pb.go | 200 +++++++++++++++++++++++++++++++++ client/disk/client.proto | 33 ++++++ client/gui.go | 14 ++- client/network.go | 112 ++++++++++-------- client/ratchet/ratchet.go | 126 ++++++++++++++++++++- client/ratchet/ratchet_test.go | 27 +++++ protos/const.go | 17 ++- protos/pond.pb.go | 2 +- protos/pond.proto | 2 +- 12 files changed, 618 insertions(+), 106 deletions(-) diff --git a/client/client.go b/client/client.go index d722a6c..f9409c1 100644 --- a/client/client.go +++ b/client/client.go @@ -69,11 +69,11 @@ import ( "sync" "time" - "code.google.com/p/go.crypto/curve25519" "code.google.com/p/goprotobuf/proto" "github.com/agl/ed25519" "github.com/agl/pond/bbssig" "github.com/agl/pond/client/disk" + "github.com/agl/pond/client/ratchet" "github.com/agl/pond/panda" pond "github.com/agl/pond/protos" ) @@ -200,6 +200,10 @@ type client struct { // nowFunc is a function that, if not nil, will be used by the GUI to // get the current time. This is used in testing. nowFunc func() time.Time + + // simulateOldClient causes the client to act like a pre-ratchet client + // for testing purposes. + simulateOldClient bool } // UI abstracts behaviour that is specific to a given interface (GUI or CLI). @@ -406,11 +410,14 @@ type Contact struct { // exchange failed. pandaResult string - lastDHPrivate [32]byte - currentDHPrivate [32]byte - + // Members for the old ratchet. + lastDHPrivate [32]byte + currentDHPrivate [32]byte theirLastDHPublic [32]byte theirCurrentDHPublic [32]byte + + // New ratchet support. + ratchet *ratchet.Ratchet } // previousTagLifetime contains the amount of time that we'll store a previous @@ -679,7 +686,7 @@ func (c *client) loadUI() error { return nil } -func (contact *Contact) processKeyExchange(kxsBytes []byte, testing bool) error { +func (contact *Contact) processKeyExchange(kxsBytes []byte, testing, simulateOldClient bool) error { var kxs pond.SignedKeyExchange if err := proto.Unmarshal(kxsBytes, &kxs); err != nil { return err @@ -723,10 +730,25 @@ func (contact *Contact) processKeyExchange(kxsBytes []byte, testing bool) error } copy(contact.theirIdentityPublic[:], kx.IdentityPublic) - if len(kx.Dh) != len(contact.theirCurrentDHPublic) { - return errors.New("invalid public DH value") + if simulateOldClient { + kx.Dh1 = nil + } + + if len(kx.Dh1) == 0 { + // They are using an old-style ratchet. We have to extract the + // private value from the Ratchet in order to use it with the + // old code. + contact.lastDHPrivate = contact.ratchet.GetKXPrivateForTransition() + if len(kx.Dh) != len(contact.theirCurrentDHPublic) { + return errors.New("invalid public DH value") + } + copy(contact.theirCurrentDHPublic[:], kx.Dh) + contact.ratchet = nil + } else { + if err := contact.ratchet.CompleteKeyExchange(&kx); err != nil { + return err + } } - copy(contact.theirCurrentDHPublic[:], kx.Dh) contact.generation = *kx.Generation @@ -774,26 +796,34 @@ func (c *client) registerId(id uint64) { c.usedIds[id] = true } +func (c *client) newRatchet(contact *Contact) *ratchet.Ratchet { + r := ratchet.New(c.rand) + r.MyIdentityPrivate = &c.identity + r.MySigningPublic = &c.pub + r.TheirIdentityPublic = &contact.theirIdentityPublic + r.TheirSigningPublic = &contact.theirPub + return r +} + func (c *client) newKeyExchange(contact *Contact) { var err error - c.randBytes(contact.lastDHPrivate[:]) - c.randBytes(contact.currentDHPrivate[:]) - - var pub [32]byte - curve25519.ScalarBaseMult(&pub, &contact.lastDHPrivate) if contact.groupKey, err = c.groupPriv.NewMember(c.rand); err != nil { panic(err) } + contact.ratchet = c.newRatchet(contact) kx := &pond.KeyExchange{ PublicKey: c.pub[:], IdentityPublic: c.identityPublic[:], Server: proto.String(c.server), - Dh: pub[:], Group: contact.groupKey.Group.Marshal(), GroupKey: contact.groupKey.Marshal(), Generation: proto.Uint32(c.generation), } + contact.ratchet.FillKeyExchange(kx) + if c.simulateOldClient { + kx.Dh1 = nil + } kxBytes, err := proto.Marshal(kx) if err != nil { diff --git a/client/client_test.go b/client/client_test.go index e6f3289..982f902 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -31,6 +31,8 @@ const clientLogToStderr = false // constant can be tweaked to enable logging to stderr. const debugDeadlock = false +const parallel = true + // logActions causes all GUI events to be written to the test log. const logActions = false @@ -336,7 +338,9 @@ func (tc *TestClient) ReloadWithMeetingPlace(mp panda.MeetingPlace) { } func TestOpenClose(t *testing.T) { - t.Parallel() + if parallel { + t.Parallel() + } client, err := NewTestClient(t, "client", nil) if err != nil { @@ -346,7 +350,9 @@ func TestOpenClose(t *testing.T) { } func TestAccountCreation(t *testing.T) { - t.Parallel() + if parallel { + t.Parallel() + } server, err := NewTestServer(t) if err != nil { @@ -480,7 +486,17 @@ func proceedToPaired(t *testing.T, client1, client2 *TestClient, server *TestSer } func TestKeyExchange(t *testing.T) { - t.Parallel() + testKeyExchange(t, false) +} + +func TestKeyExchangeCrossVersion(t *testing.T) { + testKeyExchange(t, true) +} + +func testKeyExchange(t *testing.T, crossVersion bool) { + if parallel { + t.Parallel() + } server, err := NewTestServer(t) if err != nil { @@ -494,6 +510,8 @@ func TestKeyExchange(t *testing.T) { } defer client1.Close() + client1.simulateOldClient = crossVersion + client2, err := NewTestClient(t, "client2", nil) if err != nil { t.Fatal(err) @@ -617,8 +635,8 @@ WaitForAck: func fetchMessage(client *TestClient) (from string, msg *InboxMessage) { ackChan := make(chan bool) - client.fetchNowChan <- ackChan initialInboxLen := len(client.inbox) + client.fetchNowChan <- ackChan WaitForAck: for { @@ -641,7 +659,17 @@ WaitForAck: } func TestMessageExchange(t *testing.T) { - t.Parallel() + testMessageExchange(t, false) +} + +func TestMessageExchangeCrossVersion(t *testing.T) { + testMessageExchange(t, true) +} + +func testMessageExchange(t *testing.T, crossVersion bool) { + if parallel { + t.Parallel() + } server, err := NewTestServer(t) if err != nil { @@ -654,6 +682,7 @@ func TestMessageExchange(t *testing.T) { t.Fatal(err) } defer client1.Close() + client1.simulateOldClient = crossVersion client2, err := NewTestClient(t, "client2", nil) if err != nil { @@ -664,9 +693,11 @@ func TestMessageExchange(t *testing.T) { proceedToPaired(t, client1, client2, server) var initialCurrentDH [32]byte - for _, contact := range client1.contacts { - if contact.name == "client2" { - copy(initialCurrentDH[:], contact.currentDHPrivate[:]) + if crossVersion { + for _, contact := range client1.contacts { + if contact.name == "client2" { + copy(initialCurrentDH[:], contact.currentDHPrivate[:]) + } } } @@ -691,18 +722,22 @@ func TestMessageExchange(t *testing.T) { } } - // Ensure that the DH secrets are advancing. - for _, contact := range client1.contacts { - if contact.name == "client2" { - if bytes.Equal(initialCurrentDH[:], contact.currentDHPrivate[:]) { - t.Fatalf("DH secrets aren't advancing!") + if crossVersion { + // Ensure that the DH secrets are advancing. + for _, contact := range client1.contacts { + if contact.name == "client2" { + if bytes.Equal(initialCurrentDH[:], contact.currentDHPrivate[:]) { + t.Fatalf("DH secrets aren't advancing!") + } } } } } func TestACKs(t *testing.T) { - t.Parallel() + if parallel { + t.Parallel() + } server, err := NewTestServer(t) if err != nil { @@ -768,7 +803,9 @@ WaitForAck: } func TestHalfPairedMessageExchange(t *testing.T) { - t.Parallel() + if parallel { + t.Parallel() + } server, err := NewTestServer(t) if err != nil { @@ -845,7 +882,9 @@ func TestHalfPairedMessageExchange(t *testing.T) { } func TestDraft(t *testing.T) { - t.Parallel() + if parallel { + t.Parallel() + } server, err := NewTestServer(t) if err != nil { @@ -974,7 +1013,9 @@ func TestDraft(t *testing.T) { } func TestDraftDiscard(t *testing.T) { - t.Parallel() + if parallel { + t.Parallel() + } server, err := NewTestServer(t) if err != nil { @@ -1005,7 +1046,9 @@ func TestDraftDiscard(t *testing.T) { } func testDetached(t *testing.T, upload bool) { - t.Parallel() + if parallel { + t.Parallel() + } server, err := NewTestServer(t) if err != nil { @@ -1154,7 +1197,9 @@ func TestUploadDownload(t *testing.T) { } func TestLogOverflow(t *testing.T) { - t.Parallel() + if parallel { + t.Parallel() + } server, err := NewTestServer(t) if err != nil { @@ -1226,7 +1271,9 @@ func TestServerAnnounce(t *testing.T) { } func TestRevoke(t *testing.T) { - t.Parallel() + if parallel { + t.Parallel() + } server, err := NewTestServer(t) if err != nil { @@ -1406,7 +1453,9 @@ func startPANDAKeyExchange(t *testing.T, client *TestClient, server *TestServer, } func TestPANDA(t *testing.T) { - t.Parallel() + if parallel { + t.Parallel() + } server, err := NewTestServer(t) if err != nil { @@ -1473,7 +1522,9 @@ func TestPANDA(t *testing.T) { } func TestReadingOldStateFiles(t *testing.T) { - t.Parallel() + if parallel { + t.Parallel() + } server, err := NewTestServer(t) if err != nil { @@ -1497,7 +1548,9 @@ func TestReadingOldStateFiles(t *testing.T) { func testReplyACKs(t *testing.T, reloadDraft bool, abortSend bool) { // Test that a message is acked by sending a reply. If reloadDraft is // true then the message is reloaded as draft before sending. - t.Parallel() + if parallel { + t.Parallel() + } server, err := NewTestServer(t) if err != nil { @@ -1612,7 +1665,9 @@ func TestCliId(t *testing.T) { func TestSendToPendingContact(t *testing.T) { // Test that it's not possible to send a message to a pending contact. - t.Parallel() + if parallel { + t.Parallel() + } server, err := NewTestServer(t) if err != nil { @@ -1648,7 +1703,9 @@ func TestSendToPendingContact(t *testing.T) { func TestDelete(t *testing.T) { // Test that deleting contacts works. - t.Parallel() + if parallel { + t.Parallel() + } server, err := NewTestServer(t) if err != nil { @@ -1724,7 +1781,9 @@ func TestDelete(t *testing.T) { } func TestExpireMessage(t *testing.T) { - t.Parallel() + if parallel { + t.Parallel() + } server, err := NewTestServer(t) if err != nil { @@ -1846,7 +1905,9 @@ func TestExpireMessage(t *testing.T) { } func TestRetainMessage(t *testing.T) { - t.Parallel() + if parallel { + t.Parallel() + } server, err := NewTestServer(t) if err != nil { @@ -1969,7 +2030,9 @@ func TestRetainMessage(t *testing.T) { } func TestOutboxDeletion(t *testing.T) { - t.Parallel() + if parallel { + t.Parallel() + } server, err := NewTestServer(t) if err != nil { diff --git a/client/disk.go b/client/disk.go index 782a652..773ad07 100644 --- a/client/disk.go +++ b/client/disk.go @@ -97,6 +97,14 @@ func (c *client) unmarshal(state *disk.State) error { } copy(contact.lastDHPrivate[:], cont.LastPrivate) copy(contact.currentDHPrivate[:], cont.CurrentPrivate) + + if cont.Ratchet != nil { + contact.ratchet = c.newRatchet(contact) + if err := contact.ratchet.Unmarshal(cont.Ratchet); err != nil { + return err + } + } + if cont.IsPending != nil && *cont.IsPending { contact.isPending = true continue @@ -254,10 +262,14 @@ func (c *client) marshal() []byte { cont.TheirGroup = contact.myGroupKey.Group.Marshal() cont.TheirServer = proto.String(contact.theirServer) cont.TheirPub = contact.theirPub[:] + cont.Generation = proto.Uint32(contact.generation) + cont.TheirIdentityPublic = contact.theirIdentityPublic[:] cont.TheirLastPublic = contact.theirLastDHPublic[:] cont.TheirCurrentPublic = contact.theirCurrentDHPublic[:] - cont.Generation = proto.Uint32(contact.generation) + } + if contact.ratchet != nil { + cont.Ratchet = contact.ratchet.Marshal(time.Now(), messageLifetime) } for _, prevTag := range contact.previousTags { if time.Since(prevTag.expired) > previousTagLifetime { diff --git a/client/disk/client.pb.go b/client/disk/client.pb.go index 7f2aaa2..73236ee 100644 --- a/client/disk/client.pb.go +++ b/client/disk/client.pb.go @@ -134,6 +134,7 @@ type Contact struct { CurrentPrivate []byte `protobuf:"bytes,12,opt,name=current_private" json:"current_private,omitempty"` TheirLastPublic []byte `protobuf:"bytes,13,opt,name=their_last_public" json:"their_last_public,omitempty"` TheirCurrentPublic []byte `protobuf:"bytes,14,opt,name=their_current_public" json:"their_current_public,omitempty"` + Ratchet *RatchetState `protobuf:"bytes,20,opt,name=ratchet" json:"ratchet,omitempty"` PreviousTags []*Contact_PreviousTag `protobuf:"bytes,17,rep,name=previous_tags" json:"previous_tags,omitempty"` IsPending *bool `protobuf:"varint,15,opt,name=is_pending,def=0" json:"is_pending,omitempty"` XXX_unrecognized []byte `json:"-"` @@ -264,6 +265,13 @@ func (this *Contact) GetTheirCurrentPublic() []byte { return nil } +func (this *Contact) GetRatchet() *RatchetState { + if this != nil { + return this.Ratchet + } + return nil +} + func (this *Contact) GetPreviousTags() []*Contact_PreviousTag { if this != nil { return this.PreviousTags @@ -302,6 +310,198 @@ func (this *Contact_PreviousTag) GetExpired() int64 { return 0 } +type RatchetState struct { + RootKey []byte `protobuf:"bytes,1,req,name=root_key" json:"root_key,omitempty"` + SendHeaderKey []byte `protobuf:"bytes,2,req,name=send_header_key" json:"send_header_key,omitempty"` + RecvHeaderKey []byte `protobuf:"bytes,3,req,name=recv_header_key" json:"recv_header_key,omitempty"` + NextSendHeaderKey []byte `protobuf:"bytes,4,req,name=next_send_header_key" json:"next_send_header_key,omitempty"` + NextRecvHeaderKey []byte `protobuf:"bytes,5,req,name=next_recv_header_key" json:"next_recv_header_key,omitempty"` + SendChainKey []byte `protobuf:"bytes,6,req,name=send_chain_key" json:"send_chain_key,omitempty"` + RecvChainKey []byte `protobuf:"bytes,7,req,name=recv_chain_key" json:"recv_chain_key,omitempty"` + SendRatchetPrivate []byte `protobuf:"bytes,8,req,name=send_ratchet_private" json:"send_ratchet_private,omitempty"` + RecvRatchetPublic []byte `protobuf:"bytes,9,req,name=recv_ratchet_public" json:"recv_ratchet_public,omitempty"` + SendCount *uint32 `protobuf:"varint,10,req,name=send_count" json:"send_count,omitempty"` + RecvCount *uint32 `protobuf:"varint,11,req,name=recv_count" json:"recv_count,omitempty"` + PrevSendCount *uint32 `protobuf:"varint,12,req,name=prev_send_count" json:"prev_send_count,omitempty"` + Ratchet *bool `protobuf:"varint,13,req,name=ratchet" json:"ratchet,omitempty"` + Private0 []byte `protobuf:"bytes,14,opt,name=private0" json:"private0,omitempty"` + Private1 []byte `protobuf:"bytes,15,opt,name=private1" json:"private1,omitempty"` + SavedKeys []*RatchetState_SavedKeys `protobuf:"bytes,16,rep,name=saved_keys" json:"saved_keys,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (this *RatchetState) Reset() { *this = RatchetState{} } +func (this *RatchetState) String() string { return proto.CompactTextString(this) } +func (*RatchetState) ProtoMessage() {} + +func (this *RatchetState) GetRootKey() []byte { + if this != nil { + return this.RootKey + } + return nil +} + +func (this *RatchetState) GetSendHeaderKey() []byte { + if this != nil { + return this.SendHeaderKey + } + return nil +} + +func (this *RatchetState) GetRecvHeaderKey() []byte { + if this != nil { + return this.RecvHeaderKey + } + return nil +} + +func (this *RatchetState) GetNextSendHeaderKey() []byte { + if this != nil { + return this.NextSendHeaderKey + } + return nil +} + +func (this *RatchetState) GetNextRecvHeaderKey() []byte { + if this != nil { + return this.NextRecvHeaderKey + } + return nil +} + +func (this *RatchetState) GetSendChainKey() []byte { + if this != nil { + return this.SendChainKey + } + return nil +} + +func (this *RatchetState) GetRecvChainKey() []byte { + if this != nil { + return this.RecvChainKey + } + return nil +} + +func (this *RatchetState) GetSendRatchetPrivate() []byte { + if this != nil { + return this.SendRatchetPrivate + } + return nil +} + +func (this *RatchetState) GetRecvRatchetPublic() []byte { + if this != nil { + return this.RecvRatchetPublic + } + return nil +} + +func (this *RatchetState) GetSendCount() uint32 { + if this != nil && this.SendCount != nil { + return *this.SendCount + } + return 0 +} + +func (this *RatchetState) GetRecvCount() uint32 { + if this != nil && this.RecvCount != nil { + return *this.RecvCount + } + return 0 +} + +func (this *RatchetState) GetPrevSendCount() uint32 { + if this != nil && this.PrevSendCount != nil { + return *this.PrevSendCount + } + return 0 +} + +func (this *RatchetState) GetRatchet() bool { + if this != nil && this.Ratchet != nil { + return *this.Ratchet + } + return false +} + +func (this *RatchetState) GetPrivate0() []byte { + if this != nil { + return this.Private0 + } + return nil +} + +func (this *RatchetState) GetPrivate1() []byte { + if this != nil { + return this.Private1 + } + return nil +} + +func (this *RatchetState) GetSavedKeys() []*RatchetState_SavedKeys { + if this != nil { + return this.SavedKeys + } + return nil +} + +type RatchetState_SavedKeys struct { + HeaderKey []byte `protobuf:"bytes,1,req,name=header_key" json:"header_key,omitempty"` + MessageKeys []*RatchetState_SavedKeys_MessageKey `protobuf:"bytes,2,rep,name=message_keys" json:"message_keys,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (this *RatchetState_SavedKeys) Reset() { *this = RatchetState_SavedKeys{} } +func (this *RatchetState_SavedKeys) String() string { return proto.CompactTextString(this) } +func (*RatchetState_SavedKeys) ProtoMessage() {} + +func (this *RatchetState_SavedKeys) GetHeaderKey() []byte { + if this != nil { + return this.HeaderKey + } + return nil +} + +func (this *RatchetState_SavedKeys) GetMessageKeys() []*RatchetState_SavedKeys_MessageKey { + if this != nil { + return this.MessageKeys + } + return nil +} + +type RatchetState_SavedKeys_MessageKey struct { + Num *uint32 `protobuf:"varint,1,req,name=num" json:"num,omitempty"` + Key []byte `protobuf:"bytes,2,req,name=key" json:"key,omitempty"` + CreationTime *int64 `protobuf:"varint,3,req,name=creation_time" json:"creation_time,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (this *RatchetState_SavedKeys_MessageKey) Reset() { *this = RatchetState_SavedKeys_MessageKey{} } +func (this *RatchetState_SavedKeys_MessageKey) String() string { return proto.CompactTextString(this) } +func (*RatchetState_SavedKeys_MessageKey) ProtoMessage() {} + +func (this *RatchetState_SavedKeys_MessageKey) GetNum() uint32 { + if this != nil && this.Num != nil { + return *this.Num + } + return 0 +} + +func (this *RatchetState_SavedKeys_MessageKey) GetKey() []byte { + if this != nil { + return this.Key + } + return nil +} + +func (this *RatchetState_SavedKeys_MessageKey) GetCreationTime() int64 { + if this != nil && this.CreationTime != nil { + return *this.CreationTime + } + return 0 +} + type Inbox struct { Id *uint64 `protobuf:"fixed64,1,req,name=id" json:"id,omitempty"` From *uint64 `protobuf:"fixed64,2,req,name=from" json:"from,omitempty"` diff --git a/client/disk/client.proto b/client/disk/client.proto index 5314e8f..bd46145 100644 --- a/client/disk/client.proto +++ b/client/disk/client.proto @@ -49,11 +49,14 @@ message Contact { optional string their_server = 8; optional bytes their_pub = 9; optional bytes their_identity_public = 10; + optional bytes last_private = 11; optional bytes current_private = 12; optional bytes their_last_public = 13; optional bytes their_current_public = 14; + optional RatchetState ratchet = 20; + message PreviousTag { required bytes tag = 1; required int64 expired = 2; @@ -63,6 +66,36 @@ message Contact { optional bool is_pending = 15 [ default = false ]; } +message RatchetState { + required bytes root_key = 1; + required bytes send_header_key = 2; + required bytes recv_header_key = 3; + required bytes next_send_header_key = 4; + required bytes next_recv_header_key = 5; + required bytes send_chain_key = 6; + required bytes recv_chain_key = 7; + required bytes send_ratchet_private = 8; + required bytes recv_ratchet_public = 9; + required uint32 send_count = 10; + required uint32 recv_count = 11; + required uint32 prev_send_count = 12; + required bool ratchet = 13; + + optional bytes private0 = 14; + optional bytes private1 = 15; + + message SavedKeys { + required bytes header_key = 1; + message MessageKey { + required uint32 num = 1; + required bytes key = 2; + required int64 creation_time = 3; + } + repeated MessageKey message_keys = 2; + } + repeated SavedKeys saved_keys = 16; +} + message Inbox { required fixed64 id = 1; required fixed64 from = 2; diff --git a/client/gui.go b/client/gui.go index e0de8e0..ed903da 100644 --- a/client/gui.go +++ b/client/gui.go @@ -2252,7 +2252,7 @@ func (c *guiClient) newContactManual(contact *Contact, existing bool, nextRow in c.gui.Signal() continue } - if err := contact.processKeyExchange(block.Bytes, c.dev); err != nil { + if err := contact.processKeyExchange(block.Bytes, c.dev, c.simulateOldClient); err != nil { c.gui.Actions() <- SetText{name: "error2", text: err.Error()} c.gui.Actions() <- UIError{err} c.gui.Signal() @@ -3164,8 +3164,12 @@ func (c *guiClient) composeUI(draft *Draft, inReplyTo *InboxMessage) interface{} } } - var nextDHPub [32]byte - curve25519.ScalarBaseMult(&nextDHPub, &to.currentDHPrivate) + var myNextDH []byte + if to.ratchet == nil { + var nextDHPub [32]byte + curve25519.ScalarBaseMult(&nextDHPub, &to.currentDHPrivate) + myNextDH = nextDHPub[:] + } var replyToId *uint64 if inReplyTo != nil { @@ -3185,7 +3189,7 @@ func (c *guiClient) composeUI(draft *Draft, inReplyTo *InboxMessage) interface{} Body: []byte(body), BodyEncoding: pond.Message_RAW.Enum(), InReplyTo: replyToId, - MyNextDh: nextDHPub[:], + MyNextDh: myNextDH, Files: draft.attachments, DetachedFiles: draft.detachments, SupportedVersion: proto.Int32(protoVersion), @@ -3256,7 +3260,7 @@ func (c *guiClient) processPANDAUpdate(update pandaUpdate) { contact.pandaKeyExchange = nil contact.pandaShutdownChan = nil - if err := contact.processKeyExchange(update.result, c.dev); err != nil { + if err := contact.processKeyExchange(update.result, c.dev, c.simulateOldClient); err != nil { contact.pandaResult = err.Error() c.contactsUI.SetSubline(contact.id, "failed") c.log.Printf("Key exchange with %s failed: %s", contact.name, err) diff --git a/client/network.go b/client/network.go index 7f87544..1007f15 100644 --- a/client/network.go +++ b/client/network.go @@ -39,8 +39,12 @@ const ( func (c *guiClient) sendAck(msg *InboxMessage) { to := c.contacts[msg.from] - var nextDHPub [32]byte - curve25519.ScalarBaseMult(&nextDHPub, &to.currentDHPrivate) + var myNextDH []byte + if to.ratchet == nil { + var nextDHPub [32]byte + curve25519.ScalarBaseMult(&nextDHPub, &to.currentDHPrivate) + myNextDH = nextDHPub[:] + } id := c.randId() err := c.send(to, &pond.Message{ @@ -48,7 +52,7 @@ func (c *guiClient) sendAck(msg *InboxMessage) { Time: proto.Int64(time.Now().Unix()), Body: make([]byte, 0), BodyEncoding: pond.Message_RAW.Enum(), - MyNextDh: nextDHPub[:], + MyNextDh: myNextDH, InReplyTo: msg.message.Id, SupportedVersion: proto.Int32(protoVersion), }) @@ -74,34 +78,39 @@ func (c *guiClient) send(to *Contact, message *pond.Message) error { copy(plaintext[4:], messageBytes) c.randBytes(plaintext[4+len(messageBytes):]) - // The message is encrypted to an ephemeral key so that the sending - // client can choose not to store it and then cannot decrypt it once - // sent. - - // +---------------------+ +---... - // outerNonce | ephemeral DH public | innerNonce | message - // (24 bytes) | | (24 bytes) | - // +---------------------+ +---.... - - sealedLen := ephemeralBlockLen + nonceLen + len(plaintext) + box.Overhead - sealed := make([]byte, sealedLen) - var outerNonce [24]byte - c.randBytes(outerNonce[:]) - copy(sealed, outerNonce[:]) - x := sealed[nonceLen:] + var sealed []byte + if to.ratchet != nil { + sealed = to.ratchet.Encrypt(sealed, plaintext) + } else { + // The message is encrypted to an ephemeral key so that the sending + // client can choose not to store it and then cannot decrypt it once + // sent. + + // +---------------------+ +---... + // outerNonce | ephemeral DH public | innerNonce | message + // (24 bytes) | | (24 bytes) | + // +---------------------+ +---.... + + sealedLen := ephemeralBlockLen + nonceLen + len(plaintext) + box.Overhead + sealed = make([]byte, sealedLen) + var outerNonce [24]byte + c.randBytes(outerNonce[:]) + copy(sealed, outerNonce[:]) + x := sealed[nonceLen:] + + public, private, err := box.GenerateKey(c.rand) + if err != nil { + return err + } + box.Seal(x[:0], public[:], &outerNonce, &to.theirCurrentDHPublic, &to.lastDHPrivate) + x = x[len(public)+box.Overhead:] - public, private, err := box.GenerateKey(c.rand) - if err != nil { - return err + var innerNonce [24]byte + c.randBytes(innerNonce[:]) + copy(x, innerNonce[:]) + x = x[nonceLen:] + box.Seal(x[:0], plaintext, &innerNonce, &to.theirCurrentDHPublic, private) } - box.Seal(x[:0], public[:], &outerNonce, &to.theirCurrentDHPublic, &to.lastDHPrivate) - x = x[len(public)+box.Overhead:] - - var innerNonce [24]byte - c.randBytes(innerNonce[:]) - copy(x, innerNonce[:]) - x = x[nonceLen:] - box.Seal(x[:0], plaintext, &innerNonce, &to.theirCurrentDHPublic, private) sha := sha256.New() sha.Write(sealed) @@ -205,15 +214,27 @@ func (c *guiClient) revoke(to *Contact) { c.outbox = append(c.outbox, out) } -func decryptMessage(sealed []byte, nonce *[24]byte, from *Contact) ([]byte, bool) { - // The message starts with an ephemeral block, the nonce of which has - // already been split off. See the commends in send. +func decryptMessage(sealed []byte, from *Contact) ([]byte, bool) { + if from.ratchet != nil { + plaintext, err := from.ratchet.Decrypt(sealed) + if err != nil { + return nil, false + } + return plaintext, true + } + + var nonce [24]byte + if len(sealed) < len(nonce) { + return nil, false + } + copy(nonce[:], sealed) + sealed = sealed[24:] headerLen := ephemeralBlockLen - len(nonce) if len(sealed) < headerLen { return nil, false } - publicBytes, ok := decryptMessageInner(sealed[:headerLen], nonce, from) + publicBytes, ok := decryptMessageInner(sealed[:headerLen], &nonce, from) if !ok || len(publicBytes) != 32 { return nil, false } @@ -376,10 +397,7 @@ func (c *client) unsealMessage(inboxMsg *InboxMessage, from *Contact) bool { } sealed := inboxMsg.sealed - var nonce [24]byte - copy(nonce[:], sealed) - sealed = sealed[24:] - plaintext, ok := decryptMessage(sealed, &nonce, from) + plaintext, ok := decryptMessage(sealed, from) if !ok { c.log.Errorf("Failed to decrypt message from %s", from.name) @@ -405,11 +423,6 @@ func (c *client) unsealMessage(inboxMsg *InboxMessage, from *Contact) bool { return false } - if l := len(msg.MyNextDh); l != len(from.theirCurrentDHPublic) { - c.log.Errorf("Message from %s with bad DH length %d", from.name, l) - return false - } - // Check for duplicate message. for _, candidate := range c.inbox { if candidate.from == from.id && @@ -421,10 +434,17 @@ func (c *client) unsealMessage(inboxMsg *InboxMessage, from *Contact) bool { } } - if !bytes.Equal(from.theirCurrentDHPublic[:], msg.MyNextDh) { - // We have a new DH value from them. - copy(from.theirLastDHPublic[:], from.theirCurrentDHPublic[:]) - copy(from.theirCurrentDHPublic[:], msg.MyNextDh) + if from.ratchet == nil { + if l := len(msg.MyNextDh); l != len(from.theirCurrentDHPublic) { + c.log.Errorf("Message from %s with bad DH length %d", from.name, l) + return false + } + + if !bytes.Equal(from.theirCurrentDHPublic[:], msg.MyNextDh) { + // We have a new DH value from them. + copy(from.theirLastDHPublic[:], from.theirCurrentDHPublic[:]) + copy(from.theirCurrentDHPublic[:], msg.MyNextDh) + } } if msg.InReplyTo != nil { diff --git a/client/ratchet/ratchet.go b/client/ratchet/ratchet.go index a7c5264..2f089f1 100644 --- a/client/ratchet/ratchet.go +++ b/client/ratchet/ratchet.go @@ -12,6 +12,8 @@ import ( "code.google.com/p/go.crypto/curve25519" "code.google.com/p/go.crypto/nacl/secretbox" + "code.google.com/p/goprotobuf/proto" + "github.com/agl/pond/client/disk" pond "github.com/agl/pond/protos" ) @@ -138,6 +140,13 @@ var ( chainKeyStepLabel = []byte("chain key step") ) +// GetKXPrivateForTransition returns the DH private key used in the key +// exchange. This exists in order to support the transition to the new ratchet +// format. +func (r *Ratchet) GetKXPrivateForTransition() [32]byte { + return *r.kxPrivate0 +} + // CompleteKeyExchange takes a KeyExchange message from the other party and // establishes the ratchet. func (r *Ratchet) CompleteKeyExchange(kx *pond.KeyExchange) error { @@ -165,14 +174,14 @@ func (r *Ratchet) CompleteKeyExchange(kx *pond.KeyExchange) error { return errors.New("ratchet: peer echoed our own DH values back") } + var theirDH [32]byte + copy(theirDH[:], kx.Dh) + keyMaterial := make([]byte, 0, 32*5) var sharedKey [32]byte - curve25519.ScalarMult(&sharedKey, r.MyIdentityPrivate, r.TheirIdentityPublic) + curve25519.ScalarMult(&sharedKey, r.kxPrivate0, &theirDH) keyMaterial = append(keyMaterial, sharedKey[:]...) - var theirDH [32]byte - copy(theirDH[:], kx.Dh) - if amAlice { curve25519.ScalarMult(&sharedKey, r.MyIdentityPrivate, &theirDH) keyMaterial = append(keyMaterial, sharedKey[:]...) @@ -473,3 +482,112 @@ func (r *Ratchet) Decrypt(ciphertext []byte) ([]byte, error) { return msg, nil } + +func dup(key *[32]byte) []byte { + if key == nil { + return nil + } + + ret := make([]byte, 32) + copy(ret, key[:]) + return ret +} + +func (r *Ratchet) Marshal(now time.Time, lifetime time.Duration) *disk.RatchetState { + s := &disk.RatchetState{ + RootKey: dup(&r.rootKey), + SendHeaderKey: dup(&r.sendHeaderKey), + RecvHeaderKey: dup(&r.recvHeaderKey), + NextSendHeaderKey: dup(&r.nextSendHeaderKey), + NextRecvHeaderKey: dup(&r.nextRecvHeaderKey), + SendChainKey: dup(&r.sendChainKey), + RecvChainKey: dup(&r.recvChainKey), + SendRatchetPrivate: dup(&r.sendRatchetPrivate), + RecvRatchetPublic: dup(&r.recvRatchetPublic), + SendCount: proto.Uint32(r.sendCount), + RecvCount: proto.Uint32(r.recvCount), + PrevSendCount: proto.Uint32(r.prevSendCount), + Ratchet: proto.Bool(r.ratchet), + Private0: dup(r.kxPrivate0), + Private1: dup(r.kxPrivate1), + } + + for headerKey, messageKeys := range r.saved { + keys := make([]*disk.RatchetState_SavedKeys_MessageKey, 0, len(messageKeys)) + for messageNum, savedKey := range messageKeys { + if now.Sub(savedKey.timestamp) > lifetime { + continue + } + keys = append(keys, &disk.RatchetState_SavedKeys_MessageKey{ + Num: proto.Uint32(messageNum), + Key: dup(&savedKey.key), + CreationTime: proto.Int64(savedKey.timestamp.Unix()), + }) + } + s.SavedKeys = append(s.SavedKeys, &disk.RatchetState_SavedKeys{ + HeaderKey: dup(&headerKey), + MessageKeys: keys, + }) + } + + return s +} + +func unmarshalKey(dst *[32]byte, src []byte) bool { + if len(src) != 32 { + return false + } + copy(dst[:], src) + return true +} + +var badSerialisedKeyLengthErr = errors.New("ratchet: bad serialised key length") + +func (r *Ratchet) Unmarshal(s *disk.RatchetState) error { + if !unmarshalKey(&r.rootKey, s.RootKey) || + !unmarshalKey(&r.sendHeaderKey, s.SendHeaderKey) || + !unmarshalKey(&r.recvHeaderKey, s.RecvHeaderKey) || + !unmarshalKey(&r.nextSendHeaderKey, s.NextSendHeaderKey) || + !unmarshalKey(&r.nextRecvHeaderKey, s.NextRecvHeaderKey) || + !unmarshalKey(&r.sendChainKey, s.SendChainKey) || + !unmarshalKey(&r.recvChainKey, s.RecvChainKey) || + !unmarshalKey(&r.sendRatchetPrivate, s.SendRatchetPrivate) || + !unmarshalKey(&r.recvRatchetPublic, s.RecvRatchetPublic) { + return badSerialisedKeyLengthErr + } + + r.sendCount = *s.SendCount + r.recvCount = *s.RecvCount + r.prevSendCount = *s.PrevSendCount + r.ratchet = *s.Ratchet + + if len(s.Private0) > 0 { + if !unmarshalKey(r.kxPrivate0, s.Private0) || + !unmarshalKey(r.kxPrivate1, s.Private1) { + return badSerialisedKeyLengthErr + } + } else { + r.kxPrivate0 = nil + r.kxPrivate1 = nil + } + + for _, saved := range s.SavedKeys { + var headerKey [32]byte + if !unmarshalKey(&headerKey, saved.HeaderKey) { + return badSerialisedKeyLengthErr + } + messageKeys := make(map[uint32]savedKey) + for _, messageKey := range saved.MessageKeys { + var savedKey savedKey + if !unmarshalKey(&savedKey.key, messageKey.Key) { + return badSerialisedKeyLengthErr + } + savedKey.timestamp = time.Unix(messageKey.GetCreationTime(), 0) + messageKeys[messageKey.GetNum()] = savedKey + } + + r.saved[headerKey] = messageKeys + } + + return nil +} diff --git a/client/ratchet/ratchet_test.go b/client/ratchet/ratchet_test.go index 4546e71..c8c04cc 100644 --- a/client/ratchet/ratchet_test.go +++ b/client/ratchet/ratchet_test.go @@ -5,11 +5,17 @@ import ( "crypto/rand" "io" "testing" + "time" "code.google.com/p/go.crypto/curve25519" pond "github.com/agl/pond/protos" ) +func nowFunc() time.Time { + var t time.Time + return t +} + func pairedRatchet() (a, b *Ratchet) { var privA, pubA, privB, pubB [32]byte io.ReadFull(rand.Reader, privA[:]) @@ -25,6 +31,8 @@ func pairedRatchet() (a, b *Ratchet) { io.ReadFull(rand.Reader, bSigningPublic[:]) a, b = New(rand.Reader), New(rand.Reader) + a.Now = nowFunc + b.Now = nowFunc a.MyIdentityPrivate = &privA b.MyIdentityPrivate = &privB a.TheirIdentityPublic = &pubB @@ -87,6 +95,22 @@ const ( delay ) +func reinitRatchet(t *testing.T, r *Ratchet) *Ratchet { + state := r.Marshal(nowFunc(), 1 * time.Hour) + newR := New(rand.Reader) + newR.Now = nowFunc + newR.MyIdentityPrivate = r.MyIdentityPrivate + newR.TheirIdentityPublic = r.TheirIdentityPublic + newR.MySigningPublic = r.MySigningPublic + newR.TheirSigningPublic = r.TheirSigningPublic + if err := newR.Unmarshal(state); err != nil { + t.Fatalf("Failed to unmarshal: %s", err) + } + + return newR + +} + func testScript(t *testing.T, script []scriptAction) { type delayedMessage struct { msg []byte @@ -143,6 +167,9 @@ func testScript(t *testing.T, script []scriptAction) { t.Fatalf("#%d: bad message: got %x, not %x", i, result, delayed.msg) } } + + a = reinitRatchet(t, a) + b = reinitRatchet(t, b) } } diff --git a/protos/const.go b/protos/const.go index 13977dd..bfc5e5a 100644 --- a/protos/const.go +++ b/protos/const.go @@ -1,6 +1,5 @@ package protos -import "code.google.com/p/go.crypto/nacl/box" import "code.google.com/p/go.crypto/nacl/secretbox" // TransportSize is the number of bytes that all payloads are padded to before @@ -14,9 +13,15 @@ const MessageOverhead = 512 // MaxSerializedMessage is the maximum size of the serialized Message protobuf. // The message will end up looking like this: -// [nonce - 24 bytes ] -| -// [box.Overhead - 16 bytes] -| | // [length - 4 bytes ] | NaCl box | Message that server sees. -// [serialized message ] | | -// [padding ] -| -| -const MaxSerializedMessage = TransportSize - box.Overhead - MessageOverhead - 24 - 4 +// [nonce - 24 bytes ] +// +// [secretbox.Overhead - 16 bytes] +// [message count - 4 bytes ] +// [prev message count - 4 bytes ] +// [ratchet public - 32 bytes ] +// [inner nonce - 32 bytes ] +// +// [secretbox.Overhead - 16 bytes] +// [serialized message ] +const MaxSerializedMessage = TransportSize - (secretbox.Overhead + 4 + 4 + 32 + 24) - secretbox.Overhead - MessageOverhead diff --git a/protos/pond.pb.go b/protos/pond.pb.go index fc9a077..1cdf54a 100644 --- a/protos/pond.pb.go +++ b/protos/pond.pb.go @@ -660,7 +660,7 @@ type Message struct { Time *int64 `protobuf:"varint,2,req,name=time" json:"time,omitempty"` Body []byte `protobuf:"bytes,3,req,name=body" json:"body,omitempty"` BodyEncoding *Message_Encoding `protobuf:"varint,4,opt,name=body_encoding,enum=protos.Message_Encoding" json:"body_encoding,omitempty"` - MyNextDh []byte `protobuf:"bytes,5,req,name=my_next_dh" json:"my_next_dh,omitempty"` + MyNextDh []byte `protobuf:"bytes,5,opt,name=my_next_dh" json:"my_next_dh,omitempty"` InReplyTo *uint64 `protobuf:"varint,6,opt,name=in_reply_to" json:"in_reply_to,omitempty"` Files []*Message_Attachment `protobuf:"bytes,7,rep,name=files" json:"files,omitempty"` DetachedFiles []*Message_Detachment `protobuf:"bytes,8,rep,name=detached_files" json:"detached_files,omitempty"` diff --git a/protos/pond.proto b/protos/pond.proto index a94e88d..804780c 100644 --- a/protos/pond.proto +++ b/protos/pond.proto @@ -205,7 +205,7 @@ message Message { } optional Encoding body_encoding = 4; // my_next_dh contains a Curve25519 public value for future messages. - required bytes my_next_dh = 5; + optional bytes my_next_dh = 5; // in_reply_to, if set, contains the |id| value of a previous message // sent by the recipient. optional uint64 in_reply_to = 6;