<a href="https://colab.research.google.com/github/AmrMomtaz/RSA/blob/main/RSA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RSA

Implementation of the RSA algorithm which generates PK/SK and encrypts and decrypts a given message.

In [634]:
# Defining constants
MESSAGE = 'This assignment is based on Computer and Network Security by Avinash Kak.'
KEY_SIZE = 2048 # Size of Public & Private keys (Number of bits)
BLOCK_SIZE = KEY_SIZE # Number of bits for which the message is divided into blocks of it
TRIALS = 5 # Number of trials used in Miller-Rabin primality test

# Importing libraries
import random
import math

In [635]:
# Generating two large prime number p & q of size 512 bits each
def miller_rabin_primality_test(n):
  if n < 2:
      return False
  if n == 2 or n == 3:
      return True
  if n % 2 == 0:
      return False

  r, d = 0, n-1
  while d % 2 == 0:
      r += 1
      d //= 2

  for _ in range(TRIALS):
      a = random.randint(2, n-2)
      x = pow(a, d, n)
      if x == 1 or x == n-1:
          continue
      for _ in range(r-1):
          x = pow(x, 2, n)
          if x == n-1:
              break
      else:
          return False
  return True


def generate_prime():
  while True:
      r = random.getrandbits(KEY_SIZE)
      if miller_rabin_primality_test(r):
          return r

p = generate_prime()
q = generate_prime()

In [636]:
# Generating the public exponent (e)
def extended_gcd(a, b):
  if b == 0:
      return a, 1, 0
  else:
      gcd, x1, y1 = extended_gcd(b, a % b)
      x = y1
      y = x1 - (a // b) * y1
      return gcd, x, y

probes = [3, 17, 65537] # e will be chosen from this list
e = None
for probe in probes:
  if extended_gcd(probe, p-1)[0] == 1 and extended_gcd(probe, q-1)[0] == 1:
    e = probe
    break

if e is None:
  raise Exception('The exponent isn\'t set !')

In [637]:
# Calculating (d) and defining public and private keys
d = extended_gcd(e, (p-1)*(q-1))[1]
if d < 0:
  d += (p-1)*(q-1)

n = p*q
PK = (e, n) # Public Key
SK = (d, n) # Private Key
print(f'Public Key: {PK}\nPrivate Key: {SK}')

Public Key: (17, 31998388735682309345671832469872225532962525055539846074925987320334114712869452530727346944850981945240026493052386075652352620982474071809497618368434825458939229380044862347392422816864681170456729345299772664398720920026276302165635806508898399770156667119918106134275598836375326063469423685637620398040781211376604014288898354776310638926048563273768073654796800734622528019523143365765767922552578183982893320938881706229431899532364428230774116863870173071732257161132356873903818407687277791472435566792046459678083843995170307614009186814231004934306062889642216531321465808339071404486119896339002863505057233126308495873481401083595663387045290554849698176937716975316268605305858398029179081537333880318322455349934017420216709546265993062132786357935772242564949900048048896336154539879898985539473978618302177080344578685707381358678714620837495823084955578655900322526674994133879272824330143766253734642763437145372662467626036889647380921645175477136982505483440736

In [638]:
# Defining functions which transform a given message to list of blocks of size
# ${BLOCK_SIZE} and vice versa (It pads the message if necessary).
def pad_message(message):
  binary_message = ''.join(format(ord(c), '08b') for c in message)
  padding_length = (BLOCK_SIZE - (len(binary_message) + int(BLOCK_SIZE/8)) % BLOCK_SIZE) % BLOCK_SIZE
  padded_message = binary_message + '1' + '0'*padding_length + format(len(binary_message), '064b')
  blocks = [padded_message[i:i+BLOCK_SIZE] for i in range(0, len(padded_message), BLOCK_SIZE)]
  return blocks

def unpad_message(blocks):
  binary_message = ''.join(blocks)
  message_length = int(binary_message[-int(BLOCK_SIZE/8):], 2)
  binary_chunks = [binary_message[i:i+8] for i in range(0, message_length, 8)]
  message = ''.join(chr(int(chunk, 2)) for chunk in binary_chunks)
  return message

In [639]:
# Defining the modular exponentiation function
def modular_exp(base, exponent, modulus):
  result = 1
  while exponent > 0:
      if exponent % 2 == 1:
          result = (result * base) % modulus
      base = (base * base) % modulus
      exponent //= 2
  return result

In [640]:
# Defining the encrypt and decrypt functions using PK & SK
def encrypt(message, PK):
  e, n = PK
  encrypted_blocks = list()
  blocks = pad_message(message)
  for block in blocks:
    encrypted_blocks.append(modular_exp(int(block, 2), e, n))
  return encrypted_blocks

def decrypt(enrypted_blocks, SK):
  d, n = SK
  decrypted_blocks = list()
  for enrypted_block in enrypted_blocks:
    decrypted_block = modular_exp(enrypted_block, d, n)
    decrypted_blocks.append(format(decrypted_block, f'{BLOCK_SIZE}b').replace(' ', '0'))
  if decrypted_blocks[-1] == '0'*BLOCK_SIZE:
    decrypted_blocks.pop()
    decrypted_blocks.append('0')
  else:
    decrypted_blocks.append('0' + bin(int(decrypted_blocks.pop(), 2))[2:])
  return unpad_message(decrypted_blocks)

In [641]:
# Running the encryption and decryption
encrpyted_blocks = encrypt(MESSAGE, PK)
print(f'Encrypted Blocks: {encrpyted_blocks}')
print(f'Decrypted Message: {decrypt(encrpyted_blocks, SK)}')

Encrypted Blocks: [953611827309861310797523347336842815534464009909877372459091862652501575816364975926039642848117190963168820553013915859895821082527618001673157034195666952752048933996097908399334545065794406998205612989226160274733039283777337520267801787016078775688344051579074330300936892873983929867899563576565482006201518563181636429237290150350729629326865529863824907349420396344052257020622546660144310669428110980647248567007740277990968715812497846176138236400231942803050242404414480467260769591447355997128223120616894052158609088285636028502779327835454909304601003265028723184242028826062487750208473593797016384827888271560879182403184465863524022447849169157635273026242492448175148375275012613545614241064924206971828616725250714730834444611898549652935893229206003942781750824263938795845986574552548226536823455812459753687452537917000593692406754500591121324740417153224345397015949059548635330588471005814548601817966203305159825356860631573798755641152282193991569611727445