# Breaking substitution cyphers with MCMC

In [None]:
# Import all the usual things
import math
import random
import numpy as np

# Import tools for reading files and processing strings
import string
from pathlib import Path 

# Also get tools to read the dictionary
# of probabilities back.
import json

## Preliminaries

### Read the dictionary of probabilities

The lines below read a dictionary, `logPairProbs`, that contains the logs of the probabilities $p(a,b)$. After these lines have been run:
   * `logPairProbs['t']` will be a dictionary whose keys are the symbols that might follow a `'t'` in the text;
   * `logPairProbs['t']['h']` will be $\log(p(\mbox{'t', 'h'}))$, the log of the probability that a `'t'` is followed by an `'h'`.

In [None]:
# Read the dictionary of log(p(a,b)).
# The `with` block guarantees the file is closed even if reading fails,
# and JSON text is UTF-8 (RFC 8259), so say so explicitly instead of
# relying on the platform's default encoding.
with open("LogPairProbDict.json", "r", encoding="utf-8") as probFileObj:
    jsonFromFile = probFileObj.read() # Read its contents
logPairProbs = json.loads(jsonFromFile) # Use the contents to construct the dictionary

# To make sure all is well, look at p('t','h')
print( math.exp(logPairProbs['t']['h']) )

### Define a raft of utilities

Many of these functions are defined in `EstimatePairProbs.ipynb` and `SubstitutionCyphers.ipynb`, but it's convenient to have them here too.

#### Given a dictionary of $\log(p(a,b))$, extract the alphabet of allowed characters.

The keys of the (outer) dictionary are exactly the allowed characters, so the alphabet is just those keys joined into a single string.

In [None]:
def extractAlphabet( probDict ):
    """Return the alphabet of allowed characters as a single string.

    The alphabet is simply the keys of ``probDict`` (the outer level of
    the log-pair-probability dictionary), concatenated in the order the
    dictionary iterates them.
    """
    return ''.join( symbol for symbol in probDict )

# Quick check: display the alphabet recovered from the probability table
extractAlphabet(logPairProbs)

#### Standardise a text

In [None]:
def standardiseText( rawText, allowedChars, replacementChar=' ' ):
    """Lower-case ``rawText`` and replace every disallowed character.

    Parameters
    ----------
    rawText : str
        The text to clean up. It is lower-cased first, and the allowed-
        character check is applied to the lower-cased characters.
    allowedChars : str
        Every character that may appear unchanged in the output.
    replacementChar : str, optional
        Inserted in place of each lower-cased character that is not in
        ``allowedChars`` (default: a single space).

    Returns
    -------
    str
        The standardised text.
    """
    # Build the membership table once: set lookup is O(1) per character,
    # and collecting the pieces with ''.join avoids the quadratic cost of
    # growing a str with `+` inside the loop.
    allowed = set( allowedChars )
    return ''.join(
        char if char in allowed else replacementChar
        for char in rawText.lower()
    )

# Do a small test
testText = "Where would heavy metal be without the ümlaut?"
standardiseText( testText, ' abcdefghijklmnopqrstuvwxyz', '*' )

#### Encyphering a text and computing $S^{-1}$ given $S$.

The two main things we want to do with a cypher are (a) encrypt a message and (b) figure out how to decrypt a message if we have the table of substitutions used to encypher it.

In [None]:
# Given a message and an encryption or decryption dictionary, apply it.
def applyCypher( msg, cypherDict ):
    """Substitute every character of ``msg`` using ``cypherDict``.

    The same function encyphers or decyphers: pass the encryption
    dictionary to encypher, or its inverse (from ``invertCypher``) to
    decypher.

    Raises
    ------
    KeyError
        If ``msg`` contains a character that is not a key of ``cypherDict``.
    """
    # ''.join over a generator is linear in len(msg); repeated `+=` on a
    # str is only fast thanks to a CPython-specific optimisation.
    return ''.join( cypherDict[char] for char in msg )

# Testing: result should be 'cbabc'
applyCypher( "abcba", {'a':'c', 'b':'b', 'c':'a'} )

Given a dictionary representing a substitution cypher $S$, `invertCypher` returns one that represents its inverse $S^{-1}$.

In [None]:
# Given an encryption dictionary, find the decryption dictionary
def invertCypher( cypherDict ):
    """Return the inverse substitution, so ``result[cypherDict[k]] == k``.

    The result is first seeded with ``cypherDict``'s own keys (each mapped
    to ``None``). When ``cypherDict`` is a permutation of its keys, every
    assignment below therefore overwrites an existing entry, and the
    inverse's keys come out in the same order as the input's. A key that
    never appears as a value of ``cypherDict`` keeps the placeholder
    ``None``.
    """
    inverse = { symbol: None for symbol in cypherDict }
    for plainChar, cypherChar in cypherDict.items():
        inverse[cypherChar] = plainChar
    return inverse

# Testing: result should be {'a': 'c', 'b': 'a', 'c': 'b'}
invertCypher( {'a':'b', 'b':'c', 'c':'a'} )

#### Representing cyphers

We can represent a substitution cypher in at least two ways. Perhaps the most natural approach in Python is to make a dictionary arranged so that `cypherDict[plaintextChar] = cyphertextChar`. An alternative is to arrange the keys of such a dictionary in some standard order and then just list the values in a string.

In [None]:
def cypherStrToDict( cypherStr ):
    """Turn the string form of a cypher back into a dictionary.

    The plaintext alphabet is taken to be the characters of ``cypherStr``
    in sorted order, and the j-th sorted character maps to
    ``cypherStr[j]``. If a character repeats, later pairs overwrite
    earlier ones.
    """
    return dict( zip( sorted( cypherStr ), cypherStr ) )

# Testing: result should be {'a': 'b', 'b': 'c', 'c': 'a'}
cypherStrToDict( 'bca' )

In [None]:
def cypherDictToStr( cypherDict ):
    """Return the cypher's values concatenated in the dict's key order.

    This round-trips with ``cypherStrToDict`` only when the dictionary is
    a permutation of its keys and those keys are stored in sorted order.
    """
    return ''.join( cypherDict.values() )

# Round trip: string -> dict -> string should give back the test string
testStr = 'bca'
cypherDictToStr(cypherStrToDict(testStr))

#### Generating random cyphers

Finally, here is a tool to generate random substitution cyphers.

In [None]:
# Generate a random cypher for a given alphabet
def randomCypher( alphabetStr ):
    """Return a random substitution cypher over ``alphabetStr``.

    The keys are the characters of ``alphabetStr`` in sorted order. The
    values are the same characters, rearranged by ``random.shuffle``.
    """
    plainSymbols = sorted( alphabetStr )   # standard (sorted) key order
    cypherSymbols = list( plainSymbols )   # independent copy to shuffle
    random.shuffle( cypherSymbols )
    return dict( zip( plainSymbols, cypherSymbols ) )

# Do a small test
smallAlphabet = 'abcdefg'
randomCypher( smallAlphabet )

## The MCMC code

To do MCMC, we need to be able to evaluate a log-likelihood.

In [None]:
def logLikelihood( msg, logPairProbs ):
    """Log-likelihood of ``msg`` under the pair (bigram) model.

    Sums ``logPairProbs[a][b]`` over every adjacent pair of characters
    ``(a, b)`` in ``msg``:

        log L = sum_{j=0}^{n-2} log p(msg[j], msg[j+1])

    NOTE(review): this is taken to be Eqn. (5) of the assignment, which is
    not reproduced in this notebook -- confirm against the handout.

    A pair that does not appear in ``logPairProbs`` is treated as having
    probability zero, so the result is ``-math.inf``. A message shorter
    than two characters contains no pairs and scores ``0.0``.
    """
    loglike = 0.0
    for first, second in zip( msg, msg[1:] ):
        followers = logPairProbs.get( first )
        if followers is None or second not in followers:
            # This pair never occurred: p(first, second) = 0, log = -inf.
            return -math.inf
        loglike += followers[second]

    return( loglike )

The following function does most of the work.

In [None]:
def decypherWithMCMC( cyphertext, logPairProbs, nSamples, burnIn ):
    """Sample encryption cyphers for ``cyphertext`` with Metropolis-Hastings.

    The state of the chain is an encryption cypher (a dict mapping each
    plaintext symbol to its cyphertext symbol). Each step proposes a new
    cypher by swapping the cyphertext symbols assigned to two distinct,
    randomly chosen plaintext symbols. It scores the resulting decryption
    with ``logLikelihood`` and accepts or rejects it with the Metropolis rule.

    Parameters
    ----------
    cyphertext : str
        The encrypted message. Every character must be in the alphabet
        formed by the keys of ``logPairProbs``.
    logPairProbs : dict
        Nested dictionary with ``logPairProbs[a][b] = log p(a, b)``.
    nSamples : int
        Number of cyphers to return.
    burnIn : int
        Proposals are counted from 1. The chain's state is recorded after
        every proposal whose count is at least ``burnIn``, so the first
        ``burnIn - 1`` states are discarded.

    Returns
    -------
    list of str
        ``nSamples`` encryption cyphers in the string form produced by
        ``cypherDictToStr``. Use ``cypherStrToDict`` followed by
        ``invertCypher`` to get a decryption dictionary back.

    Raises
    ------
    ValueError
        If the alphabet has fewer than two symbols, so no swap can be proposed.
    """
    # Examine the input to get the alphabet of allowed characters
    myAlphabet = extractAlphabet( logPairProbs )
    if len( myAlphabet ) < 2:
        raise ValueError( "need at least two symbols in the alphabet to propose a swap" )

    # Step (1) Initialise the MCMC run by choosing a cypher at random.
    # This is equivalent to sampling from a uniform prior over keys.
    crntCypherDict = randomCypher( myAlphabet )

    # Step (2) Decrypt the cyphertext using crntCypherDict
    crntDecryptDict = invertCypher( crntCypherDict )
    crntPlaintext = applyCypher( cyphertext, crntDecryptDict )

    # Step (3) Compute the log-likelihood
    crntLoglike = logLikelihood( crntPlaintext, logPairProbs )

    # Do the sampling
    nProposed = 0
    nAccepted = 0
    sampleNum = 0
    samples = [''] * nSamples # filled in once burn-in is over
    while sampleNum < nSamples:
        # Step (4): generate a proposal by swapping the cyphertext symbols
        # assigned to two distinct plaintext symbols. The move is
        # symmetric, so the Hastings correction factor is 1.
        symA, symB = random.sample( myAlphabet, 2 )
        proposedCypherDict = crntCypherDict.copy()
        proposedCypherDict[symA] = crntCypherDict[symB]
        proposedCypherDict[symB] = crntCypherDict[symA]

        # Step (5) Get the plaintext implied by the proposed cypher
        proposedDecryptDict = invertCypher( proposedCypherDict )
        proposedPlaintext = applyCypher( cyphertext, proposedDecryptDict )

        # Step (6) Compute the log-likelihood
        proposedLoglike = logLikelihood( proposedPlaintext, logPairProbs )
        nProposed += 1

        # Step (7) Metropolis acceptance: always accept a proposal that is
        # at least as likely, otherwise accept with probability
        # L_proposed / L_current. Testing `>=` first means exp() only ever
        # sees a negative argument (no overflow), and two -inf scores never
        # produce nan from (-inf) - (-inf).
        acceptProposal = ( proposedLoglike >= crntLoglike
                           or random.random() < math.exp( proposedLoglike - crntLoglike ) )

        if( acceptProposal ):
            crntCypherDict = proposedCypherDict
            crntPlaintext = proposedPlaintext
            crntLoglike = proposedLoglike
            nAccepted += 1

        if( nProposed >= burnIn ):
            samples[sampleNum] = cypherDictToStr( crntCypherDict )
            sampleNum += 1

    # Report the acceptance ratio (only defined if anything was proposed)
    # and return the samples.
    if nProposed > 0:
        print( nAccepted / nProposed )
    return( samples )

### Try with a real text

In [None]:
# Extract the alphabet of allowed characters from logPairProbs
lppAlphabet = extractAlphabet( logPairProbs )

# Read and standardise a plaintext. Decode explicitly as UTF-8 instead of
# using the platform's default encoding. Any bytes that are not valid UTF-8
# become U+FFFD, which (not expected to be in lppAlphabet) standardiseText
# then replaces like any other disallowed character.
rawPlaintext = Path('SamplePlaintext.txt').read_text(encoding='utf-8', errors='replace')
plaintext = standardiseText( rawPlaintext, lppAlphabet )

# Generate a random cypher and apply it
myCypher = randomCypher( lppAlphabet )
cyphertext = applyCypher( plaintext, myCypher )

# Do MCMC
burnIn = 8000
nSamples = 10
samples = decypherWithMCMC( cyphertext, logPairProbs, nSamples, burnIn )

In [None]:
# See how well it worked: decrypt the cyphertext with each sampled cypher
# and keep the first 50 characters for comparison. The comprehension
# builds the list directly, without pre-sizing it and filling it by index,
# and keeps the loop temporaries out of the notebook's global namespace.
decryptedFragment = [
    applyCypher( cyphertext, invertCypher( cypherStrToDict( cypherStr ) ) )[0:50]
    for cypherStr in samples
]

decryptedFragment