<a href="https://colab.research.google.com/github/aaydenn/rsa-algorithm/blob/main/RSA_algorithm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RSA algorithm

## Generate keys



1. Choose two large prime numbers *p* and *q*.
2. Compute *n = pq*.
 - n is used as the modulus for both the public and private keys
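3. Compute λ(n) = lcm(p-1, q-1) (Carmichael's totient function).
 - λ(n) is kept secret.
 - The implementation below uses Euler's totient φ(n) = (p-1)(q-1) instead. φ(n) is a multiple of λ(n), so keys derived from it are also valid.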
4. Choose an integer e such that 2 < e < λ(n) and greatest common divisor(e, λ(n)) = 1.
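 - e should usually have a short bit-length and a small Hamming weight, which makes encryption fast. (The implementation below instead picks the candidate with the largest Hamming weight.)
 - e is part of the public key.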
5. Determine d as $d \equiv e^{-1} \pmod{\lambda(n)}$, i.e. the modular multiplicative inverse of e.
 - Solve $de \equiv 1 \pmod{\lambda(n)}$ for d.
 - d is part of the private key.


In [152]:
from dataclasses import dataclass
from sympy import isprime
from math import gcd, isqrt, prod

@dataclass
class RSA:
  """Textbook (toy) RSA built from two primes ``p`` and ``q``.

  Right after construction both numbers are checked for primality and
  reordered so that ``p >= q``. Keys are derived from Euler's totient
  phi(n) = (p-1)(q-1). Messages are encrypted one character at a time with
  no padding, so this is for teaching only.
  """
  p: int
  q: int
  kwargs: dict = None  # not used by any method; kept for interface compatibility


  def __post_init__(self):
    """Validate and normalize the primes right after dataclass initialization."""
    self.is_prime()
    self.is_ordered()


  def is_prime(self):
    """Raise ValueError unless both ``p`` and ``q`` are prime."""
    if not isprime(self.p) or not isprime(self.q):
      raise ValueError("Both numbers must be prime.")


  def is_ordered(self):
    """Swap ``p`` and ``q`` in place so that ``p >= q``."""
    if self.p < self.q:
      self.p, self.q = self.q, self.p


  def n(self, p = None, q = None):
    """Return the modulus n = p * q shared by the public and private keys.

    Args:
      p: first factor; defaults to ``self.p``.
      q: second factor; defaults to ``self.q``.
    """
    if p is None:
      p = self.p
    if q is None:
      q = self.q
    return p * q


  def totient(self, n = None):
    """Return Euler's totient phi(n), computed with exact integer arithmetic.

    Args:
      n: value to evaluate; defaults to the instance modulus ``p * q``.

    For the instance modulus the known factorization is used directly:
    phi(pq) = (p-1)(q-1), or p(p-1) when p == q. Any other ``n`` is
    factored by trial division up to sqrt(n).
    """
    if n is None:
      p, q = self.p, self.q
      # p and q were validated as prime in __post_init__, so no factoring
      return p * (p - 1) if p == q else (p - 1) * (q - 1)

    # phi(n) = n * prod(1 - 1/f) over the distinct prime factors f of n.
    # Applying each factor as ``phi -= phi // f`` keeps everything an exact
    # integer; a float product can round e.g. 575.999... down to 575.
    phi = n
    remaining = n
    factor = 2
    while factor * factor <= remaining:
      if remaining % factor == 0:
        while remaining % factor == 0:
          remaining //= factor
        phi -= phi // factor
      factor += 1
    if remaining > 1:
      # whatever is left is a single prime factor larger than sqrt(n)
      phi -= phi // remaining
    return phi


  def e(self, phi = None, choose = True):
    """Return the public exponent, or every candidate for it.

    Args:
      phi: totient to work against; defaults to ``self.totient()``.
      choose: if True return one exponent; otherwise return the ascending
        list of all integers 2 <= e < phi with gcd(e, phi) == 1.

    Raises:
      ValueError: if ``choose`` is True and no candidate exists.

    NOTE: the chosen exponent is the candidate with the *largest* Hamming
    weight (ties broken by the smaller value). That is the opposite of the
    usual advice to prefer a low-weight exponent. It is kept so that
    existing keys do not change.
    """
    if phi is None:
      phi = self.totient()

    # candidates smaller than phi that share no factor with it
    candidates = [i for i in range(2, phi) if gcd(i, phi) == 1]
    if not choose:
      return candidates
    if not candidates:
      raise ValueError(f"No public exponent is coprime to phi = {phi}.")

    # bin() yields '0b...'; the prefix contains no '1', so count() is the weight
    return min(candidates, key = lambda c: (-bin(c).count("1"), c))


  def d(self, n = None, e = None, phi = None):
    """Return ``(k, d)``, where d is the private exponent.

    d is the modular inverse of e modulo phi, i.e. the smallest positive
    solution of d * e = 1 + k * phi. k >= 1 is the matching multiplier.

    Args:
      n: unused; accepted for backward compatibility.
      e: public exponent; defaults to ``self.e()``.
      phi: totient; defaults to ``self.totient()``.

    Raises:
      ValueError: if e has no inverse modulo phi (gcd(e, phi) != 1).
    """
    if e is None:
      e = self.e()
    if phi is None:
      phi = self.totient()

    # Three-argument pow with exponent -1 (Python 3.8+) computes the modular
    # inverse via the extended Euclidean algorithm instead of a linear search.
    try:
      d = pow(e, -1, phi)
    except ValueError as err:
      raise ValueError(f"e = {e} has no inverse modulo phi = {phi}.") from err

    if d * e == 1:
      # only possible for e == 1: skip the trivial k = 0 solution so k >= 1
      d += phi
    k = (d * e - 1) // phi
    return k, d


  def keys(self, n = None, e = None, d = None):
    """Return ``((n, e), (n, d))``: the public key and the private key.

    Any component that is not supplied is derived from the instance. A
    missing d is computed from the e that is actually returned, so a custom
    e always gets a matching private exponent.
    """
    if n is None:
      n = self.n()
    if e is None:
      e = self.e()
    if d is None:
      _, d = self.d(e = e)
    return (n, e), (n, d)


  def encode(self, public:tuple = None, message:str = None):
    """Encrypt ``message`` one character at a time.

    Each character is mapped to its Unicode code point m and encrypted as
    c = m**e mod n.

    Args:
      public: ``(n, e)`` key; defaults to the instance's public key.
      message: plaintext string.

    Returns:
      A list of ciphertext integers, one per character.

    Raises:
      ValueError: if a character's code point is >= n. Such characters
        would be reduced mod n and could not be recovered by decryption.
    """
    if public is None:
      public, _ = self.keys()
    n, e = public[0], public[1]

    codes = [ord(ch) for ch in message]
    too_large = sorted({chr(c) for c in codes if c >= n})
    if too_large:
      raise ValueError(
        f"Characters {too_large} have code points >= modulus n = {n} "
        "and cannot be encrypted reversibly."
      )

    # pow(m, e, n) reduces at every step instead of building m**e first
    return [pow(c, e, n) for c in codes]


  def decode(self, private:tuple = None, message = None, readable = False):
    """Decrypt ciphertext integers, optionally converting them back to text.

    Each value c is decrypted as m = c**d mod n.

    Args:
      private: ``(n, d)`` key; defaults to the instance's private key.
      message: iterable of ciphertext integers, e.g. from ``encode``.
      readable: if True, join the decrypted code points into a string.

    Returns:
      The decrypted code points as a list, or a string when ``readable``.
    """
    if private is None:
      _, private = self.keys()
    n, d = private[0], private[1]

    codes = [pow(c, d, n) for c in message]
    if readable:
      return "".join(chr(c) for c in codes)
    return codes

In [153]:
# Build a toy instance; the dataclass reorders the primes so p >= q.
rsa = RSA(p=17, q=37)
public, private = rsa.keys()

print(f"public key: {public}")
print(f"private key: {private}")

public key: (629, 511)
private key: (629, 319)


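## Encrypt/decrypt plaintext

Each character is converted to its Unicode code point m and processed separately. For decryption to recover it, m must be smaller than n.

To encrypt plaintext: $$ c = m^e \bmod n $$
To decrypt ciphertext: $$ m = c^d \bmod n $$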

In [156]:
# Round-trip a short phrase through the instance's default keys.
message = "non est servus tuus motus"
enc = rsa.encode(message=message)
# decrypt twice: once as raw code points, once joined back into text
dec, dec_rdble = (rsa.decode(message=enc, readable=flag) for flag in (False, True))

print(f"message: {message}")
print(f"encrypted message: {enc}")
print(f"decrypted message: {dec}")
print(f"human-readable decrypted message: {dec_rdble}")

message: non est servus tuus motus
encrypted message: [406, 444, 406, 93, 101, 548, 351, 93, 548, 101, 78, 441, 586, 548, 93, 351, 586, 586, 548, 93, 464, 444, 351, 586, 548]
decrypted message: [110, 111, 110, 32, 101, 115, 116, 32, 115, 101, 114, 118, 117, 115, 32, 116, 117, 117, 115, 32, 109, 111, 116, 117, 115]
human-readable decrypted message: non est servus tuus motus
