From c5496839c9a1b8990400ed6d03178368fdb54223 Mon Sep 17 00:00:00 2001 From: Philipp Winter Date: Sat, 8 Mar 2014 16:42:31 +0100 Subject: [PATCH] When authenticating, also test epoch boundaries. On occasion, a client's or a server's epoch might already have increased whereas the epoch of the other party didn't. This is a benign event and there is no reason to fail authentication because of this. As a result, as a server, we now also test boundary values, i.e., epoch - 1, epoch, epoch + 1. --- scramblesuit.py | 22 ++++++++++++------- uniformdh.py | 22 ++++++++++++------- unittests.py | 57 +++++++++++++++++++++++++++++++++++++++++++++++++ util.py | 10 +++++++++ 4 files changed, 95 insertions(+), 16 deletions(-) diff --git a/scramblesuit.py b/scramblesuit.py index 9262b34..c737902 100644 --- a/scramblesuit.py +++ b/scramblesuit.py @@ -388,14 +388,20 @@ def receiveTicket( self, data ): existingHMAC = potentialTicket[index + const.MARK_LENGTH: index + const.MARK_LENGTH + const.HMAC_SHA256_128_LENGTH] - myHMAC = mycrypto.HMAC_SHA256_128(self.recvHMAC, - potentialTicket[0: - index + const.MARK_LENGTH] + - util.getEpoch()) - - if not util.isValidHMAC(myHMAC, existingHMAC, self.recvHMAC): - log.warning("The HMAC is invalid: `%s' vs. `%s'." % - (myHMAC.encode('hex'), existingHMAC.encode('hex'))) + authenticated = False + for epoch in util.expandedEpoch(): + myHMAC = mycrypto.HMAC_SHA256_128(self.recvHMAC, + potentialTicket[0:index + \ + const.MARK_LENGTH] + epoch) + + if util.isValidHMAC(myHMAC, existingHMAC, self.recvHMAC): + authenticated = True + break + + log.debug("HMAC invalid. Trying next epoch value.") + + if not authenticated: + log.warning("Could not verify the authentication message's HMAC.") return False # Do nothing if the ticket is replayed. Immediately closing the diff --git a/uniformdh.py b/uniformdh.py index 1b59575..dd16070 100644 --- a/uniformdh.py +++ b/uniformdh.py @@ -120,19 +120,25 @@ def extractPublicKey( self, data, srvState=None ): if not index: return False - self.echoEpoch = util.getEpoch() - # Now that we know where the authenticating HMAC is: verify it. hmacStart = index + const.MARK_LENGTH existingHMAC = handshake[hmacStart: (hmacStart + const.HMAC_SHA256_128_LENGTH)] - myHMAC = mycrypto.HMAC_SHA256_128(self.sharedSecret, - handshake[0 : hmacStart] + - self.echoEpoch) - if not util.isValidHMAC(myHMAC, existingHMAC, self.sharedSecret): - log.warning("The HMAC is invalid: `%s' vs. `%s'." % - (myHMAC.encode('hex'), existingHMAC.encode('hex'))) + authenticated = False + for epoch in util.expandedEpoch(): + myHMAC = mycrypto.HMAC_SHA256_128(self.sharedSecret, + handshake[0 : hmacStart] + epoch) + + if util.isValidHMAC(myHMAC, existingHMAC, self.sharedSecret): + self.echoEpoch = epoch + authenticated = True + break + + log.debug("HMAC invalid. Trying next epoch value.") + + if not authenticated: + log.warning("Could not verify the authentication message's HMAC.") return False # Do nothing if the ticket is replayed. Immediately closing the diff --git a/unittests.py b/unittests.py index 24feb7c..b3ca388 100644 --- a/unittests.py +++ b/unittests.py @@ -9,6 +9,8 @@ import base64 import shutil import tempfile +import ticket +import state import message @@ -160,6 +162,29 @@ def test3_invalidHMAC( self ): self.failIf(self.udh.receivePublicKey(buf, lambda x: x) == True) + def test4_extractPublicKey( self ): + + # Create UniformDH authentication message. + sharedSecret = "A" * const.SHARED_SECRET_LENGTH + + realEpoch = util.getEpoch + + # Try three valid and one invalid epoch value. + for epoch in util.expandedEpoch() + ["000000"]: + udh = uniformdh.new(sharedSecret, True) + + util.getEpoch = lambda: epoch + authMsg = udh.createHandshake() + util.getEpoch = realEpoch + + buf = obfs_buf.Buffer() + buf.write(authMsg) + + if epoch == "000000": + self.assertFalse(udh.extractPublicKey(buf)) + else: + self.assertTrue(udh.extractPublicKey(buf)) + class UtilTest( unittest.TestCase ): def test1_isValidHMAC( self ): @@ -308,6 +333,38 @@ def test4_ProtocolMessage( self ): self.assertRaises(base.PluggableTransportError, message.ProtocolMessage, "1", paddingLen=const.MPU) +class TicketTest( unittest.TestCase ): + + def test1_authentication( self ): + srvState = state.State() + srvState.genState() + + ss = scramblesuit.ScrambleSuitTransport() + ss.srvState = srvState + + realEpoch = util.getEpoch + + # Try three valid and one invalid epoch value. + for epoch in util.expandedEpoch() + ["000000"]: + + util.getEpoch = lambda: epoch + + # Prepare ticket message. + blurb = ticket.issueTicketAndKey(srvState) + rawTicket = blurb[const.MASTER_KEY_LENGTH:] + masterKey = blurb[:const.MASTER_KEY_LENGTH] + ss.deriveSecrets(masterKey) + ticketMsg = ticket.createTicketMessage(rawTicket, ss.recvHMAC) + + util.getEpoch = realEpoch + + buf = obfs_buf.Buffer() + buf.write(ticketMsg) + + if epoch == "000000": + self.assertFalse(ss.receiveTicket(buf)) + else: + self.assertTrue(ss.receiveTicket(buf)) if __name__ == '__main__': # Disable all logging as it would yield plenty of warning and error diff --git a/util.py b/util.py index bbb6c6a..f22c9b9 100644 --- a/util.py +++ b/util.py @@ -106,6 +106,16 @@ def getEpoch( ): return str(int(time.time()) / const.EPOCH_GRANULARITY) +def expandedEpoch( ): + """ + Return [epoch, epoch-1, epoch+1]. + """ + + epoch = int(getEpoch()) + + return [str(epoch), str(epoch - 1), str(epoch + 1)] + + def writeToFile( data, fileName ): """ Writes the given `data' to the file specified by `fileName'.