Skip to content

Commit

Permalink
Add 'DecryptAndGetKey' method for Session Ciphers. Fix bug where decr…
Browse files Browse the repository at this point in the history
…ypt would mutate whisper message.
  • Loading branch information
ShadowApex committed Apr 3, 2017
1 parent 0b0e62a commit 8ee59a2
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 23 deletions.
66 changes: 43 additions & 23 deletions session/SessionCipher.go
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/RadicalApp/libsignal-protocol-go/protocol"
"github.com/RadicalApp/libsignal-protocol-go/state/record"
"github.com/RadicalApp/libsignal-protocol-go/state/store"
"github.com/RadicalApp/libsignal-protocol-go/util/bytehelper"
"strconv"
)

Expand All @@ -33,7 +34,7 @@ func NewCipher(builder *Builder, remoteAddress *protocol.SignalAddress) *Cipher
func NewCipherFromSession(session *record.Session, remoteAddress *protocol.SignalAddress,
sessionStore store.Session, preKeyStore store.PreKey,
preKeyMessageSerializer protocol.PreKeySignalMessageSerializer,
signalMessageSerializer protocol.SignalMessageSerializer) *Cipher{
signalMessageSerializer protocol.SignalMessageSerializer) *Cipher {
cipher := &Cipher{
sessionStore: sessionStore,
preKeyMessageSerializer: preKeyMessageSerializer,
Expand Down Expand Up @@ -123,39 +124,60 @@ func (d *Cipher) Encrypt(plaintext []byte) (protocol.CiphertextMessage, error) {
// Decrypt decrypts the given message using an existing session that
// is stored in the session store.
func (d *Cipher) Decrypt(ciphertextMessage *protocol.SignalMessage) ([]byte, error) {
plaintext, _, err := d.DecryptAndGetKey(ciphertextMessage)

return plaintext, err
}

// DecryptAndGetKey decrypts the given message using an existing session that
// is stored in the session store and returns the message keys used for encryption.
func (d *Cipher) DecryptAndGetKey(ciphertextMessage *protocol.SignalMessage) ([]byte, *message.Keys, error) {
if !d.sessionStore.ContainsSession(d.remoteAddress) {
return nil, errors.New("No session for: " + d.remoteAddress.String())
return nil, nil, errors.New("No session for: " + d.remoteAddress.String())
}

// Load the session record from our session store and decrypt the message.
sessionRecord := d.sessionStore.LoadSession(d.remoteAddress)
plaintext, err := d.DecryptWithRecord(sessionRecord, ciphertextMessage)
plaintext, messageKeys, err := d.DecryptWithRecord(sessionRecord, ciphertextMessage)
if err != nil {
return nil, err
return nil, nil, err
}

// Store the session record in our session store.
d.sessionStore.StoreSession(d.remoteAddress, sessionRecord)

return plaintext, messageKeys, nil
}

// DecryptWithKey will decrypt the given message using the given symmetric key. This
// can be used when decrypting messages at a later time if the message key was saved.
func (d *Cipher) DecryptWithKey(ciphertextMessage *protocol.SignalMessage, key *message.Keys) ([]byte, error) {
logger.Debug("Decrypting ciphertext body: ", ciphertextMessage.Body())
plaintext, err := decrypt(key, ciphertextMessage.Body())
if err != nil {
logger.Error("Unable to get plain text from ciphertext: ", err)
return nil, err
}

return plaintext, nil
}

// DecryptWithRecord decrypts the given message using the given session record.
func (d *Cipher) DecryptWithRecord(sessionRecord *record.Session, ciphertext *protocol.SignalMessage) ([]byte, error) {
func (d *Cipher) DecryptWithRecord(sessionRecord *record.Session, ciphertext *protocol.SignalMessage) ([]byte, *message.Keys, error) {
logger.Debug("Decrypting ciphertext with record: ", sessionRecord)
previousStates := sessionRecord.PreviousSessionStates()
sessionState := sessionRecord.SessionState()

// Try and decrypt the message with the current session state.
plaintext, err := d.DecryptWithState(sessionState, ciphertext)
plaintext, messageKeys, err := d.DecryptWithState(sessionState, ciphertext)

// If we received an error using the current session state, loop
// through all previous states.
if err != nil {
logger.Warning(err)
for i, state := range previousStates {
// Try decrypting the message with previous states
plaintext, err = d.DecryptWithState(state, ciphertext)
plaintext, messageKeys, err = d.DecryptWithState(state, ciphertext)
if err != nil {
continue
}
Expand All @@ -164,31 +186,31 @@ func (d *Cipher) DecryptWithRecord(sessionRecord *record.Session, ciphertext *pr
previousStates = append(previousStates[:i], previousStates[i+1:]...)
sessionRecord.PromoteState(state)

return plaintext, nil
return plaintext, messageKeys, nil
}

return nil, errors.New("No valid sessions.")
return nil, nil, errors.New("No valid sessions.")
}

// If decryption was successful, set the session state and return the plain text.
sessionRecord.SetState(sessionState)

return plaintext, nil
return plaintext, messageKeys, nil
}

// DecryptWithState decrypts the given message with the given session state.
func (d *Cipher) DecryptWithState(sessionState *record.State, ciphertextMessage *protocol.SignalMessage) ([]byte, error) {
func (d *Cipher) DecryptWithState(sessionState *record.State, ciphertextMessage *protocol.SignalMessage) ([]byte, *message.Keys, error) {
logger.Debug("Decrypting ciphertext with session state: ", sessionState)
if !sessionState.HasSenderChain() {
err := "Uninitialized session!"
logger.Error("Unable to decrypt message with state: ", err)
return nil, errors.New(err)
return nil, nil, errors.New(err)
}

if ciphertextMessage.MessageVersion() != sessionState.Version() {
err := "Wrong message version!"
logger.Error("Unable to decrypt message with state: ", err)
return nil, errors.New(err)
return nil, nil, errors.New(err)
}

messageVersion := ciphertextMessage.MessageVersion()
Expand All @@ -197,31 +219,29 @@ func (d *Cipher) DecryptWithState(sessionState *record.State, ciphertextMessage
chainKey, chainCreateErr := getOrCreateChainKey(sessionState, theirEphemeral)
if chainCreateErr != nil {
logger.Error("Unable to get or create chain key: ", chainCreateErr)
return nil, chainCreateErr
return nil, nil, chainCreateErr
}

messageKeys, keysCreateErr := getOrCreateMessageKeys(sessionState, theirEphemeral, chainKey, counter)
if keysCreateErr != nil {
logger.Error("Unable to get or create message keys: ", keysCreateErr)
return nil, keysCreateErr
return nil, nil, keysCreateErr
}

err := ciphertextMessage.VerifyMac(messageVersion, sessionState.RemoteIdentityKey(), sessionState.LocalIdentityKey(), messageKeys.MacKey())
if err != nil {
logger.Error("Unable to verify ciphertext mac: ", err)
return nil, err
return nil, nil, err
}

logger.Debug("Decrypting ciphertext body: ", ciphertextMessage.Body())
plaintext, pErr := decrypt(messageKeys, ciphertextMessage.Body())
if pErr != nil {
logger.Error("Unable to get plain text from ciphertext: ", pErr)
return nil, pErr
plaintext, err := d.DecryptWithKey(ciphertextMessage, messageKeys)
if err != nil {
return nil, nil, err
}

sessionState.ClearUnackPreKeyMessage()

return plaintext, nil
return plaintext, messageKeys, nil
}

func getOrCreateMessageKeys(sessionState *record.State, theirEphemeral ecc.ECPublicKeyable,
Expand Down Expand Up @@ -296,7 +316,7 @@ func getOrCreateChainKey(sessionState *record.State, theirEphemeral ecc.ECPublic
// the plaintext bytes.
func decrypt(keys *message.Keys, body []byte) ([]byte, error) {
logger.Debug("Using cipherKey: ", keys.CipherKey())
return cipher.Decrypt(keys.Iv(), keys.CipherKey(), body)
return cipher.Decrypt(keys.Iv(), keys.CipherKey(), bytehelper.CopySlice(body))
}

// encrypt will use the given cipher, message keys, and plaintext bytes
Expand Down
110 changes: 110 additions & 0 deletions tests/saved_message_keys_test.go
@@ -0,0 +1,110 @@
package tests

import (
"github.com/RadicalApp/libsignal-protocol-go/keys/message"
"github.com/RadicalApp/libsignal-protocol-go/keys/prekey"
"github.com/RadicalApp/libsignal-protocol-go/logger"
"github.com/RadicalApp/libsignal-protocol-go/protocol"
"github.com/RadicalApp/libsignal-protocol-go/session"
"testing"
)

// TestSavedMessageKeys tests the ability to save message keys for use in
// decrypting messages in the future.
func TestSavedMessageKeys(t *testing.T) {

// Create a serializer object that will be used to encode/decode data.
serializer := newSerializer()

// Create our users who will talk to each other.
alice := newUser("Alice", 1, serializer)
bob := newUser("Bob", 2, serializer)

// Create a session builder to create a session between Alice -> Bob.
alice.buildSession(bob.address, serializer)
bob.buildSession(alice.address, serializer)

// Create a PreKeyBundle from Bob's prekey records and other
// data.
logger.Debug("Fetching Bob's prekey with ID: ", bob.preKeys[0].ID())
retrievedPreKey := prekey.NewBundle(
bob.registrationID,
bob.deviceID,
bob.preKeys[0].ID(),
bob.signedPreKey.ID(),
bob.preKeys[0].KeyPair().PublicKey(),
bob.signedPreKey.KeyPair().PublicKey(),
bob.signedPreKey.Signature(),
bob.identityKeyPair.PublicKey(),
)

// Process Bob's retrieved prekey to establish a session.
logger.Debug("Building sender's (Alice) session...")
err := alice.sessionBuilder.ProcessBundle(retrievedPreKey)
if err != nil {
logger.Error("Unable to process retrieved prekey bundle")
t.FailNow()
}

// Create a session cipher to encrypt messages to Bob.
plaintextMessage := []byte("Hello!")
logger.Info("Plaintext message: ", string(plaintextMessage))
sessionCipher := session.NewCipher(alice.sessionBuilder, bob.address)
message, err := sessionCipher.Encrypt(plaintextMessage)
if err != nil {
logger.Error("Unable to encrypt message: ", err)
t.FailNow()
}

logger.Info("Encrypted message: ", message)

///////////// RECEIVER SESSION CREATION ///////////////

// Emulate receiving the message as JSON over the network.
logger.Debug("Building message from bytes on Bob's end.")
receivedMessage, err := protocol.NewPreKeySignalMessageFromBytes(message.Serialize(), serializer.PreKeySignalMessage, serializer.SignalMessage)
if err != nil {
logger.Error("Unable to emulate receiving message as JSON: ", err)
t.FailNow()
}

// Create a session builder
logger.Debug("Building receiver's (Bob) session...")
unsignedPreKeyID, err := bob.sessionBuilder.Process(receivedMessage)
if err != nil {
logger.Error("Unable to process prekeysignal message: ", err)
t.FailNow()
}
logger.Debug("Got PreKeyID: ", unsignedPreKeyID)

// Try and decrypt the message and get the message key.
bobSessionCipher := session.NewCipher(bob.sessionBuilder, alice.address)
msg, key, err := bobSessionCipher.DecryptAndGetKey(receivedMessage.WhisperMessage())
if err != nil {
logger.Error("Unable to decrypt message: ", err)
t.FailNow()
}
logger.Info("Decrypted message: ", string(msg))
if string(msg) != string(plaintextMessage) {
logger.Error("Decrypted string does not match - Encrypted: ", string(plaintextMessage), " Decrypted: ", string(msg))
t.FailNow()
}

// Try using the message key to decrypt the message again.
logger.Info("Testing using saved message key to decrypt again.")
for i := 0; i < 10; i++ {
testDecryptingWithKey(bobSessionCipher, receivedMessage.WhisperMessage(), key, plaintextMessage, t)
}
}

func testDecryptingWithKey(cipher *session.Cipher, receivedMessage *protocol.SignalMessage, key *message.Keys, plaintextMessage []byte, t *testing.T) {
msg, err := cipher.DecryptWithKey(receivedMessage, key)
if err != nil {
t.FailNow()
}
logger.Info("Decrypted message: ", string(msg))
if string(msg) != string(plaintextMessage) {
logger.Error("Decrypted string does not match - Encrypted: ", string(plaintextMessage), " Decrypted: ", string(msg))
t.FailNow()
}
}
8 changes: 8 additions & 0 deletions util/bytehelper/ByteHelper.go
Expand Up @@ -87,3 +87,11 @@ func Bytes5ToInt64(bytes []byte, offset int) int64 {

return value
}

// CopySlice returns a copy of the given bytes.
func CopySlice(bytes []byte) []byte {
cp := make([]byte, len(bytes))
copy(cp, bytes)

return cp
}

0 comments on commit 8ee59a2

Please sign in to comment.