<a href="https://colab.research.google.com/github/AmrMomtaz/RSA/blob/main/RSA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RSA

Implementation of the RSA algorithm that generates a public/private key pair (PK/SK) and uses it to encrypt and decrypt a given message.

In [None]:
# Defining constants
# Bit length used when drawing each of the two primes p and q; the modulus
# n = p*q therefore has up to 2*KEY_SIZE bits (this is NOT the size of n).
KEY_SIZE = 4096
# Number of bits per plaintext block. Every block must encode an integer
# smaller than n, which holds in practice because n has roughly 2*KEY_SIZE bits.
BLOCK_SIZE = KEY_SIZE
# Number of Miller-Rabin rounds run on each prime candidate.
TRIALS = 5
# Demo plaintext passed to encrypt()/decrypt().
MESSAGE = 'This assignment is based on Computer and Network Security by Avinash Kak.'

# Importing libraries
import random
import math

In [None]:
# Generating two large prime numbers p & q
# Small odd primes used for cheap trial division before the comparatively
# expensive Miller-Rabin rounds; this rejects most random odd candidates.
_SMALL_PRIMES = (3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59,
                 61, 67, 71, 73, 79, 83, 89, 97)


def miller_rabin_primality_test(n, trials=None):
  """Probabilistically test whether `n` is prime.

  Args:
    n: integer to test.
    trials: number of Miller-Rabin rounds with random witnesses; defaults to
      the module constant TRIALS.

  Returns:
    False if `n` is certainly composite (or < 2); True if `n` is prime or
    a probable prime. For composite `n` that survives trial division, each
    round wrongly reports "prime" with probability at most 1/4.
  """
  if trials is None:
    trials = TRIALS
  if n < 2:
    return False
  if n in (2, 3):
    return True
  if n % 2 == 0:
    return False
  # Exact answer whenever n has a small factor (or is itself a small prime).
  for small_prime in _SMALL_PRIMES:
    if n % small_prime == 0:
      return n == small_prime

  # Write n - 1 = 2**twos * odd_part with odd_part odd.
  twos, odd_part = 0, n - 1
  while odd_part % 2 == 0:
    twos += 1
    odd_part //= 2

  for _ in range(trials):
    witness = random.randint(2, n - 2)
    x = pow(witness, odd_part, n)
    if x == 1 or x == n - 1:
      continue
    for _ in range(twos - 1):
      x = pow(x, 2, n)
      if x == n - 1:
        break
    else:
      # Never reached -1 mod n: `witness` proves n composite.
      return False
  return True


# OS-backed CSPRNG: RSA primes must be unpredictable, and the default
# Mersenne Twister behind random.getrandbits is not suitable for secrets.
_CSPRNG = random.SystemRandom()


def generate_prime(bits=None):
  """Return a random probable prime with exactly `bits` bits.

  Args:
    bits: bit length of the prime; defaults to the module constant KEY_SIZE.

  Returns:
    An odd integer in [2**(bits-1), 2**bits) that passed
    miller_rabin_primality_test.

  Raises:
    ValueError: if `bits` < 2.
  """
  if bits is None:
    bits = KEY_SIZE
  if bits < 2:
    raise ValueError('bits must be at least 2')
  while True:
    # Force the top bit so the prime really has `bits` bits (keeping n = p*q
    # at full size) and the low bit so even numbers are never tested.
    candidate = _CSPRNG.getrandbits(bits) | (1 << (bits - 1)) | 1
    if miller_rabin_primality_test(candidate):
      return candidate

# Draw the pair (p first, then q) and redraw both until they differ:
# p == q would make n a perfect square and (p-1)*(q-1) the wrong totient.
p = q = None
while p == q:
  p = generate_prime()
  q = generate_prime()

In [None]:
# Generating the public exponent (e)
def extended_gcd(a, b):
  """Extended Euclidean algorithm.

  Args:
    a, b: integers.

  Returns:
    A tuple (gcd, x, y) such that a*x + b*y == gcd. The coefficients are the
    same ones the classic recursive formulation yields.

  Implemented iteratively: the recursive form needs one stack frame per
  Euclid step and overflows Python's recursion limit when both inputs are
  large (thousands of steps for multi-thousand-bit numbers).
  """
  old_r, r = a, b
  old_x, x = 1, 0
  old_y, y = 0, 1
  while r != 0:
    quotient = old_r // r
    old_r, r = r, old_r - quotient * r
    old_x, x = x, old_x - quotient * x
    old_y, y = y, old_y - quotient * y
  return old_r, old_x, old_y

# Candidate public exponents, most preferred first. 65537 is the conventional
# choice; a tiny exponent such as 3 leaves unpadded ("textbook") RSA open to
# cube-root and Hastad broadcast attacks, so it is only a last resort.
probes = [65537, 17, 3]
e = None
for probe in probes:
  # e must be invertible modulo (p-1)*(q-1), i.e. coprime to p-1 and to q-1.
  if math.gcd(probe, p - 1) == 1 and math.gcd(probe, q - 1) == 1:
    e = probe
    break

if e is None:
  raise RuntimeError('The exponent isn\'t set !')

In [None]:
# Calculating (d) and defining public and private keys
# phi is Euler's totient of n = p*q (p and q are distinct primes).
phi = (p - 1) * (q - 1)
# d is the inverse of e modulo phi: the Bezout coefficient of e, shifted into
# the non-negative range when extended_gcd returns it negative.
d = extended_gcd(e, phi)[1]
d = d + phi if d < 0 else d

n = p * q
PK = (e, n)  # Public Key
SK = (d, n)  # Private Key
# NOTE: printing the private key is acceptable only for this demo.
print(f'Public Key: {PK}\nPrivate Key: {SK}')

Public Key: (3, 241487482056250721310678125990796370625741926099575548066338827838136189343571086105440914981487988211723771963075688285993567822352305304226260624443194480542766886785746507207267908466234761133633454495499754808120252849339105825417216881524829741005285805888489278686189410073053134257214381537960787219224742347918172247744216699703632274917318073752999678419297924101413243885035855792736062424971042667858658712319728838338178430457787934842471194003637112751828612943577426504806966100009026661212925348788253497145269406447744462804690582651990518485771516312209972720360903836259236641666917766648929567983835716427944399892586097573257668031429853732029893242701691676610846622227578680453302611414873678139577230533409997980502592742936058725749834982559327306520298217461144446128485517338120294512465147736772457745467708509894245855055082398217627452431280206404835394153501108133441456073881814413825135074336235270403427841464145812091010158332616176577317402510195404

In [None]:
# Defining functions which transform a given message into a list of blocks of
# up to BLOCK_SIZE bits each, and vice versa (padding the message as needed).
# Width in bits of the trailing field that records the message's bit length.
_LENGTH_FIELD_BITS = 64


def pad_message(message, block_size=None):
  """Encode `message` as bits and split it into equal-size blocks.

  Layout of the padded bit string:
    UTF-8 bits of the message | '1' | zero fill | bit length (64-bit field)
  The zero fill is chosen so the total length is an exact multiple of
  `block_size`, so every returned block has exactly `block_size` bits.

  Args:
    message: text to encode (any Unicode; encoded as UTF-8).
    block_size: bits per block; defaults to the module constant BLOCK_SIZE.

  Returns:
    A non-empty list of '0'/'1' strings, each of length `block_size`.

  Raises:
    ValueError: if `block_size` < 1 or the message is too long for the
      length field.
  """
  if block_size is None:
    block_size = BLOCK_SIZE
  if block_size < 1:
    raise ValueError('block_size must be positive')
  # UTF-8 keeps every unit at exactly 8 bits; format(ord(c), '08b') would
  # emit wider groups for characters above U+00FF and break decoding.
  message_bits = ''.join(format(byte, '08b') for byte in message.encode('utf-8'))
  if len(message_bits) >= 1 << _LENGTH_FIELD_BITS:
    raise ValueError('message too long for the length field')
  fill = -(len(message_bits) + 1 + _LENGTH_FIELD_BITS) % block_size
  padded = (message_bits + '1' + '0' * fill
            + format(len(message_bits), f'0{_LENGTH_FIELD_BITS}b'))
  return [padded[i:i + block_size] for i in range(0, len(padded), block_size)]


def unpad_message(blocks, block_size=None):
  """Rebuild the text that pad_message() split into `blocks`.

  Every block is first re-rendered as exactly `block_size` bits from its
  integer value. Blocks whose leading zeros were dropped or changed are
  therefore still decoded correctly; decrypt() rebuilds the final block
  from its integer value in exactly that way.

  Args:
    blocks: list of '0'/'1' strings produced by pad_message() (possibly
      with altered leading zeros).
    block_size: bits per block; defaults to the module constant BLOCK_SIZE.

  Returns:
    The original message string.

  Raises:
    ValueError: if `blocks` is empty or the length field is inconsistent.
  """
  if block_size is None:
    block_size = BLOCK_SIZE
  if not blocks:
    raise ValueError('cannot unpad an empty list of blocks')
  bits = ''.join(format(int(block, 2), f'0{block_size}b') for block in blocks)
  message_length = int(bits[-_LENGTH_FIELD_BITS:], 2)
  # The message must be whole bytes and leave room for the '1' marker and
  # the length field.
  if message_length % 8 or message_length > len(bits) - _LENGTH_FIELD_BITS - 1:
    raise ValueError('corrupt padding: invalid message length field')
  message_bytes = bytes(int(bits[i:i + 8], 2) for i in range(0, message_length, 8))
  return message_bytes.decode('utf-8')

In [None]:
# Defining the modular exponentiation function
def modular_exp(base, exponent, modulus):
  """Return base**exponent mod modulus using right-to-left binary exponentiation.

  Args:
    base: any integer (negative values are reduced mod `modulus`).
    exponent: non-negative integer.
    modulus: positive integer.

  Returns:
    An integer in [0, modulus); matches the built-in pow(base, exponent, modulus).

  Raises:
    ValueError: if `exponent` is negative or `modulus` < 1.
  """
  if exponent < 0:
    raise ValueError('exponent must be non-negative')
  if modulus < 1:
    raise ValueError('modulus must be positive')
  # Start from 1 % modulus so that modulus == 1 yields 0 even for exponent 0.
  result = 1 % modulus
  base %= modulus
  while exponent:
    if exponent & 1:
      result = result * base % modulus
    base = base * base % modulus
    exponent >>= 1
  return result

In [None]:
# Defining the encrypt and decrypt functions using PK & SK
def encrypt(message, PK):
  """Encrypt `message` with the public key PK = (e, n).

  The message is split into bit-string blocks by pad_message(); each block
  is read as a binary integer m and encrypted as m**e mod n.

  Args:
    message: text to encrypt.
    PK: public key tuple (e, n).

  Returns:
    A list of integer ciphertext blocks, one per padded block.

  Raises:
    ValueError: if a block's integer value is not below n (it could not be
      recovered by decryption).
  """
  e, n = PK
  encrypted_blocks = []
  for block in pad_message(message):
    plain_int = int(block, 2)
    # RSA only round-trips plaintexts in [0, n); a larger value would be
    # silently reduced mod n and lost.
    if plain_int >= n:
      raise ValueError('Message block does not fit below the modulus; use a smaller BLOCK_SIZE or a larger key')
    encrypted_blocks.append(modular_exp(plain_int, e, n))
  return encrypted_blocks

def decrypt(enrypted_blocks, SK):
  """Decrypt integer ciphertext blocks with the private key SK = (d, n).

  Each block is raised to d mod n and rendered as a zero-filled binary
  string of BLOCK_SIZE bits. The last block is then rewritten as '0'
  followed by its significant bits (or just '0' when it is zero), and the
  resulting list is decoded by unpad_message().

  NOTE(review): in general that last-block rewrite restores the original
  block only when the block began with '01'; it is coupled to the block
  layout of pad_message()/unpad_message().
  """
  private_exp, modulus = SK
  plain_bits = [
    format(modular_exp(cipher_block, private_exp, modulus), f'{BLOCK_SIZE}b').replace(' ', '0')
    for cipher_block in enrypted_blocks
  ]
  last = plain_bits[-1]
  plain_bits[-1] = '0' if last == '0' * BLOCK_SIZE else '0' + bin(int(last, 2))[2:]
  return unpad_message(plain_bits)

In [None]:
# Running the encryption and decryption
# Encrypt the demo message with the public key, show the ciphertext blocks,
# then decrypt them with the private key and show the recovered text.
encrpyted_blocks = encrypt(MESSAGE, PK)
print(f'Encrypted Blocks: {encrpyted_blocks}')
decrypted_message = decrypt(encrpyted_blocks, SK)
print(f'Decrypted Message: {decrypted_message}')

Encrypted Blocks: [188538713900392173293220673221506872068034859273691148671111358150833019154430677642223751320085833704071886602216078582094780946569919944343838468575044307643723562269252255894287130348837038886041352647281010623156435140402401526568491125925949321848989471102902109468334153119020995510578670868123731771394179191203521296694577993599600198025989259190628799932988479004030280844236362105789426002399763844217532781180035439884492425835002751611019761340325017751778636779404740649901375382530498876630094202293851331917296481679507441210568098389821572959619571277290903543694136354839208552147061214093243574124548451286252289353712492719734401514749484768957442690615784102551483598155405864427346711141546963599211248741415781837296824079140511870422639423792105685651875697427093030359970140327224875912424602765537875843687592873723357102197351075960011842199430941041799829969659935480258171972196750255148003488008788122661639344593530104266274460172966819971385623326911