# Salsa

In [36]:
from typing import Tuple, List
from scripts.utils import *
from scripts.math import sum32

## Pad function

In [5]:
# Salsa20 constants: the ASCII string "expand 32-byte k" read as four
# little-endian 32-bit words.
_c = [0x61707865, 0x3320646e, 0x79622d32, 0x6b206574]

def pad(s: int, j: int, n: int) -> List[int]:
    """Combine a seed, a counter and a nonce into one 512-bit Salsa20 input block.

    The 256-bit seed ``s`` (words s0..s7), the 64-bit counter ``j`` (words
    j0, j1) and the 64-bit nonce ``n`` (words n0, n1) are interleaved with
    the four constants c0..c3 from ``_c`` into 16 words x0, ..., x15 of
    32 bits, laid out row by row as

        c0  s0  s1  s2
        s3  c1  n0  n1
        j0  j1  c2  s4
        s5  s6  s7  c3

    Args:
        s: Seed (256 bits, 8 words of 32 bits)
        j: Counter (64 bits, 2 words of 32 bits)
        n: Nonce (64 bits, 2 words of 32 bits)

    Returns:
        The 512-bit block as a list of 16 words of 32 bits.
    """
    # Split each integer into a fixed number of 32-bit words with get_words.
    seed = get_words(s, 32, 8)
    counter = get_words(j, 32, 2)
    nonce = get_words(n, 32, 2)

    return [
        _c[0],      seed[0],    seed[1],  seed[2],
        seed[3],    _c[1],      nonce[0], nonce[1],
        counter[0], counter[1], _c[2],    seed[4],
        seed[5],    seed[6],    seed[7],  _c[3],
    ]


In [6]:
# A 128-bit seed is zero-extended to 8 words: x1..x4 are 0, x11..x14 hold it.
assert pad(
    0x112233445566778899aabbccddeeff00,  # seed s
    0x0123456789abcdef,                  # counter j
    0xfedcba9876543210,                  # nonce n
) == [
    0x61707865, 0x00000000, 0x00000000, 0x00000000,
    0x00000000, 0x3320646e, 0xfedcba98, 0x76543210,
    0x01234567, 0x89abcdef, 0x79622d32, 0x11223344,
    0x55667788, 0x99aabbcc, 0xddeeff00, 0x6b206574,
], "Error"

In [7]:
# Full 256-bit seed: its eight words land in x1..x4 and x11..x14.
assert pad(
    0x47f515b1dd45f8d5aceea73b52971be21f7b4b3355a35fd6a2799898ed2f8c97,  # seed s
    0x722d9d570ac23201,                                                  # counter j
    0xed539cd99e1d2f20,                                                  # nonce n
) == [
    0x61707865, 0x47f515b1, 0xdd45f8d5, 0xaceea73b,
    0x52971be2, 0x3320646e, 0xed539cd9, 0x9e1d2f20,
    0x722d9d57, 0x0ac23201, 0x79622d32, 0x1f7b4b33,
    0x55a35fd6, 0xa2799898, 0xed2f8c97, 0x6b206574,
], "Error"

## Public permutation

The public permutation $\pi$ is constructed by iterating a simple permutation (the double round defined below) a fixed number of times.

### Quarterround function

- Input: 4-word sequence 
- Output: 4-word sequence 

The entire function is invertible.

In [8]:
def rot(w: int, r: int) -> int:
    """Rotate a 32-bit word left by ``r`` bit positions.

    Args:
        w: Word to rotate; only its low 32 bits are used.
        r: Number of bit positions to rotate left. It is taken modulo 32,
            so ``r = 0`` and ``r = 32`` both leave ``w`` unchanged and a
            negative ``r`` rotates right.

    Returns:
        The rotated 32-bit word.
    """
    mask = 0xffffffff
    w &= mask  # keep the operand to 32 bits so no high bits leak into the result
    r %= 32    # rotation by 32 is the identity; also avoids negative shift counts
    return ((w << r) & mask) | (w >> (32 - r))

In [11]:
def QR(a: int, b: int, c: int, d: int) -> Tuple[int, int, int, int]:
    """Salsa20 quarterround on four 32-bit words.

    Runs four update steps in order. Each step XORs into one word the left
    rotation of ``sum32`` of two other words, and always uses the values
    already updated by earlier steps:

        b ^= rot(a + d, 7)
        c ^= rot(b + a, 9)
        d ^= rot(c + b, 13)
        a ^= rot(d + c, 18)

    where ``+`` stands for ``sum32`` (presumably addition modulo 2**32).

    Returns:
        The updated words as the tuple (a, b, c, d).
    """
    words = [a, b, c, d]
    # (index updated, first addend index, second addend index, rotation amount)
    steps = ((1, 0, 3, 7), (2, 1, 0, 9), (3, 2, 1, 13), (0, 3, 2, 18))
    for dst, x, y, shift in steps:
        words[dst] ^= rot(sum32(words[x], words[y]), shift)
    return tuple(words)


In [12]:
# Quarterround test vectors from the Salsa20 specification:
# (input words, expected output words).
qr_vectors = [
    ((0x00000000, 0x00000000, 0x00000000, 0x00000000),
     (0x00000000, 0x00000000, 0x00000000, 0x00000000)),
    ((0x00000001, 0x00000000, 0x00000000, 0x00000000),
     (0x08008145, 0x00000080, 0x00010200, 0x20500000)),
    ((0x00000000, 0x00000001, 0x00000000, 0x00000000),
     (0x88000100, 0x00000001, 0x00000200, 0x00402000)),
    ((0x00000000, 0x00000000, 0x00000001, 0x00000000),
     (0x80040000, 0x00000000, 0x00000001, 0x00002000)),
    ((0x00000000, 0x00000000, 0x00000000, 0x00000001),
     (0x00048044, 0x00000080, 0x00010000, 0x20100001)),
    ((0xe7e8c006, 0xc4f9417d, 0x6479b4b2, 0x68c67137),
     (0xe876d72b, 0x9361dfd5, 0xf1460244, 0x948541a3)),
    ((0xd3917c5b, 0x55f1c407, 0x52a58a7a, 0x8f887a3b),
     (0x3e2f308c, 0xd90a8f36, 0x6ab2a923, 0x2883524c)),
]

# Test numbers start at 1 so failure messages read "Failed test 1" .. "Failed test 7".
for test_no, (qr_in, qr_out) in enumerate(qr_vectors, start=1):
    assert QR(*qr_in) == qr_out, f"Failed test {test_no}"

In [15]:
# One more quarterround check on a fixed input.
qr_args = (0xc2619378, 0xecdaec96, 0xe62bd0c8, 0x2b61be56)
assert QR(*qr_args) == (0x21158c0a, 0x0d720be0, 0x41156157, 0xc6c75786), "Error"

### Round Function

In [16]:
def round(i: List[int]):
    """One Salsa20 double round: a column round followed by a row round.

    Applies QR to four index groups of the 16-word state, first the column
    groups and then the row groups, writing each result back into the same
    positions. ``i`` is updated in place and the same list object is returned.

    Note: this name shadows Python's built-in ``round`` in this notebook.
    """
    # Odd round (column round), then even round (row round).
    column_groups = ((0, 4, 8, 12), (5, 9, 13, 1), (10, 14, 2, 6), (15, 3, 7, 11))
    row_groups = ((0, 1, 2, 3), (5, 6, 7, 4), (10, 11, 8, 9), (15, 12, 13, 14))

    for p, q, r, t in column_groups + row_groups:
        i[p], i[q], i[r], i[t] = QR(i[p], i[q], i[r], i[t])

    return i

In [80]:
# Double-round test vectors (Salsa20 specification).
assert round([0x00000001] + [0x00000000] * 15) == [
    0x8186a22d, 0x0040a284, 0x82479210, 0x06929051,
    0x08000090, 0x02402200, 0x00004000, 0x00800000,
    0x00010200, 0x20400000, 0x08008104, 0x00000000,
    0x20500000, 0xa0000040, 0x0008180a, 0x612a8020,
], "Failed test 1"

assert round([
    0xde501066, 0x6f9eb8f7, 0xe4fbbd9b, 0x454e3f57,
    0xb75540d3, 0x43e93a4c, 0x3a6f2aa0, 0x726d6b36,
    0x9243f484, 0x9145d1e8, 0x4fa9d247, 0xdc8dee11,
    0x054bf545, 0x254dd653, 0xd9421b6d, 0x67b276c1,
]) == [
    0xccaaf672, 0x23d960f7, 0x9153e63a, 0xcd9a60d0,
    0x50440492, 0xf07cad19, 0xae344aa0, 0xdf4cfdfc,
    0xca531c29, 0x8e7943db, 0xac1680cd, 0xd503ca00,
    0xa74b2ad6, 0xbc331c5c, 0x1dda24c7, 0xee928277,
], "Failed test 2"

### The perm function

In [110]:

def perm(x: List[int], ROUNDS: int = 20) -> List[int]:
    """Salsa20 public permutation: ``ROUNDS // 2`` double rounds on a copy of ``x``.

    Args:
        x: 512-bit block as a sequence of 16 words of 32 bits; it is not modified.
        ROUNDS: Number of rounds
            - 20 (Salsa20/20, default)
            - 12 (Salsa20/12)
            - 8 (Salsa20/8)

    Returns:
        A new list of 16 words of 32 bits.

    Raises:
        AssertionError: if ROUNDS is not 8, 12 or 20.
    """
    assert ROUNDS in (8, 12, 20), "Invalid number of rounds"

    # Work on a copy: round() writes its results back into the list it is given.
    state = list(x)

    # Each call to round() performs a double round (column round + row round).
    for _ in range(ROUNDS // 2):
        round(state)

    return state


In [137]:
# Salsa20/20 permutation of the block built from this seed, counter and nonce.
s = 0x47f515b1dd45f8d5aceea73b52971be21f7b4b3355a35fd6a2799898ed2f8c97
j = 0x722d9d570ac23201
n = 0xed539cd99e1d2f20

expected_perm = 0x4ae1c9a7e960b635dc60a70e05f3a06b6d5333b853e0b60d7fe901e08289149820c71b7f7bf63cd69222987510bb60608551ec51c1e23b31da929437ccb2cb58

_pad = pad(s, j, n)
assert from_words(perm(_pad), 32) == expected_perm

## Pseudorandom generator

In [34]:
L = 1  # default number of 512-bit blocks produced by G

def G(s: int, n: int, blocks: "int | None" = None) -> List[int]:
    """Salsa20-based pseudorandom generator.

    For each counter value j = 0, ..., blocks - 1:
      - build h = pad(s, j, n);
      - compute perm(h);
      - add perm(h) and h word by word with sum32 (the feed-forward);
      - pack the 16 resulting 32-bit words into one integer with from_words.

    Args:
        s: Seed (256 bits)
        n: Nonce (64 bits)
        blocks: Number of blocks to generate. Defaults to the module-level
            constant ``L``, which is read at call time.

    Returns:
        A list of ``blocks`` integers, each one 512-bit output block.
    """
    count = L if blocks is None else blocks

    r = []
    for j in range(count):
        h = pad(s, j, n)  # 16 words of 32 bits
        pi = perm(h)      # 16 words of 32 bits
        # Feed-forward: perm is a public, invertible permutation, so publishing
        # perm(h) alone would reveal h (and the seed). Adding h word by word
        # prevents that.
        block_words = [sum32(p, q) for p, q in zip(pi, h)]
        r.append(from_words(block_words, 32))

    return r

## Example 

In [None]:
# XOR the message with the first block returned by G (one-time-pad style),
# then check that XOR-ing the ciphertext with the same block recovers it.
keystream = G(0x112233445566778899aabbccddeeff00, 0x0123456789abcdef)
k = keystream[0]
m = text_to_int("Hello word!")
c = m ^ k
assert (c ^ k) == m, "Error"

## References
- Class notes 
- [Salsa20 specification - Daniel J. Bernstein](http://cr.yp.to/snuffle/spec.pdf)
- [Dan Boneh and Victor Shoup. *A Graduate Course in Applied Cryptography*](https://toc.cryptobook.us/)