# RSA Blinding Attack
## Intro
RSA (named after its creators Ronald Linn Rivest, Adi Shamir and Leonard Adleman) is an example of an asymmetric cryptosystem that can be used for secure communication and for signing data. The basic principle behind it is as follows.

RSA uses the multiplicative group modulo $N=pq$, where $p$ and $q$ are distinct prime numbers. The order of the multiplicative group (effectively the number of elements in this group) can be calculated with Euler's totient function: $\varphi(N)=(p-1)(q-1)$. We just cross out the numbers less than $N$ that are multiples of $p$ or $q$. All the other numbers are in the multiplicative group, since they are coprime with $N$. As we know, if we raise any member of a group to the order of the group, the result is the group's identity: $a^{\varphi(N)}=1\space mod\space N, a \in Z^{*}_{N}$. So two numbers, $e$ (public exponent) and $d$ (private exponent), are chosen such that $ed=1\space mod\space \varphi(N)$. The pair of numbers $(e,N)$ becomes the public key and $(d,N)$ becomes the private key. Calculating $d$ from public information is considered a hard problem (unless certain mistakes were made). You would have to factor $N$ to get $p$ and $q$.
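
As a quick sanity check of the group-order claim above, here is a minimal sketch with deliberately tiny primes. The values $p=11$ and $q=13$ are illustrative assumptions, far too small for real use. It counts the elements coprime with $N$ and confirms that raising each of them to the power $\varphi(N)$ gives 1.

```python
from math import gcd

# Toy primes chosen only for illustration; real RSA primes are hundreds of digits long
p, q = 11, 13
N = p * q
phi = (p - 1) * (q - 1)

# Elements of the multiplicative group: numbers in [1, N) that are coprime with N
group = [a for a in range(1, N) if gcd(a, N) == 1]

# The group order matches Euler's totient function
assert len(group) == phi

# Raising any group element to the group order gives the identity
assert all(pow(a, phi, N) == 1 for a in group)

print("group order:", len(group), "phi(N):", phi)
```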

Given a message $M$ with $M < N$, a public key $(e,N)$ and a private key $(d,N)$, the encryption and decryption processes are as follows:

Encryption

$C=M^{e}\space mod\space N$

Decryption

$M=C^{d} \space mod\space N$

It's quite easy to check correctness (for $M$ coprime with $N$, so that the group property above applies):

$C^{d}\space mod\space N=M^{ed}\space mod\space N= M^{ed\space mod\space \varphi(N)}\space mod\space N=M^{1}\space mod\space N= M\space mod\space N$
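
The chain of equalities above can be tried out on numbers small enough to check by hand. This is only a sketch: the parameters $p=61$, $q=53$, $e=17$ and the message $M=65$ are assumed toy values, and `pow(e, -1, phi)` needs Python 3.8 or newer.

```python
# Toy parameters (illustrative assumptions, not secure)
p, q = 61, 53
N = p * q                 # 3233
phi = (p - 1) * (q - 1)   # 3120
e = 17
d = pow(e, -1, phi)       # modular inverse of e modulo phi; needs Python 3.8+

# e and d cancel out modulo phi(N)
assert (e * d) % phi == 1

M = 65
C = pow(M, e, N)          # encryption: C = M^e mod N
M1 = pow(C, d, N)         # decryption: M = C^d mod N
assert M1 == M
print("d:", d, "C:", C, "M1:", M1)
```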

## Preparation
Let's try to work with RSA for a bit. If you haven't yet, install Pycryptodome. On Linux and Windows this should work (you obviously have to install Python 3 and pip first, but I hope you know how to do that):


In [1]:
!python3 -m pip install pycryptodome

Collecting pycryptodome
  Using cached https://files.pythonhosted.org/packages/54/e4/72132c31a4cedc58848615502c06cedcce1e1ff703b4c506a7171f005a75/pycryptodome-3.9.6-cp36-cp36m-manylinux1_x86_64.whl
Installing collected packages: pycryptodome
Successfully installed pycryptodome-3.9.6


You'll have to restart the Jupyter kernel after installation (the circular arrow near "Run"). If you encounter problems, you can follow this installation guide: [Pycryptodome installation](https://pycryptodome.readthedocs.io/en/latest/src/installation.html).

## Primitive RSA
So let's try to implement RSA. Let's use the public exponent $e=65537$. This constant is usually chosen nowadays because $M^{e}$ exceeds the modulus (and so wraps around $N$) even for a message as small as $M=2$, and because it has a sparse binary representation $65537_{10}=10000000000000001_{2}$, which allows for efficient exponentiation using the "square and multiply" method.
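
The "square and multiply" idea can be sketched as follows. The function name `square_and_multiply` is hypothetical, introduced here only for illustration; in practice Python's built-in `pow(base, exp, mod)` does the same job.

```python
def square_and_multiply(base, exponent, modulus):
    """Left-to-right binary exponentiation: (base ** exponent) % modulus."""
    result = 1
    for bit in bin(exponent)[2:]:                 # exponent bits, most significant first
        result = (result * result) % modulus      # square on every bit
        if bit == '1':
            result = (result * base) % modulus    # multiply only on 1-bits
    return result

e = 65537
# 65537 has only two 1-bits, so the loop multiplies by the base just twice
assert bin(e).count('1') == 2
# Cross-check against the built-in modular exponentiation (toy modulus 3233)
assert square_and_multiply(2, e, 3233) == pow(2, e, 3233)
```

With a dense exponent of the same size, roughly half of the bits would trigger an extra multiplication, which is why a sparse exponent is cheaper.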

First, generate the primes $p$ and $q$. The getStrongPrime function lets you choose the number of bits in your prime and checks that $gcd(p-1,e)=1$.

In [2]:
try:
    from Crypto.Util.number import getStrongPrime, inverse,bytes_to_long, long_to_bytes
except ImportError:
    print ("Pycryptodome not installed")

In [3]:
e=65537
p=getStrongPrime(1024,e=e)
q=getStrongPrime(1024,e=e)
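
The coprimality condition $gcd(p-1,e)=1$ matters because the private exponent exists only if $e$ is invertible modulo $(p-1)(q-1)$. A minimal sketch with assumed toy values (`p_bad`, `p_ok` and $e=5$ are hypothetical names and numbers chosen for illustration; `pow` with a negative exponent needs Python 3.8 or newer):

```python
from math import gcd

e = 5
# Toy primes: p_bad - 1 = 10 shares the factor 5 with e, p_ok - 1 = 22 does not
p_bad, p_ok, q = 11, 23, 17

phi_bad = (p_bad - 1) * (q - 1)
phi_ok = (p_ok - 1) * (q - 1)

# With the bad prime, e has no inverse, so no private exponent can be computed
assert gcd(e, phi_bad) != 1
try:
    pow(e, -1, phi_bad)       # raises ValueError when no inverse exists
    inverse_exists = True
except ValueError:
    inverse_exists = False
assert not inverse_exists

# With the good prime, the inverse exists and behaves as expected
assert gcd(e, phi_ok) == 1
d = pow(e, -1, phi_ok)
assert (e * d) % phi_ok == 1
```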

In [4]:
N=p*q
phi=(p-1)*(q-1)
d=inverse(e,phi)
public_key=(e,N)
private_key=(d,N)

Ok, we've generated the keys. Let's encrypt a message, decrypt the ciphertext and check that we get the same message back.

In [5]:
M=bytes_to_long(b'Hello, RSA!')
C=pow(M,e,N)
print ('C:',hex(C))
M1=pow(C,d,N)
assert M1==M
print ('M1:',long_to_bytes(M1))

C: 0x57d826694f86a3120e40f5a86bfbd451a91e40886345fcbeb360582f9334bfd5d0734a0f2d11068da5aa851fb07d265ac3cddf47092a9d6e8049801bf721db5a8ae2a78dbae899212a14c22e9a7128f54158b143f22410997184c75b1945a30b940d921e43e05a401bffb3c356ed2134d503c4a112b9b3782e3c85a9b5985d7b836a6b9fdffe941bfaac5555583716e8667c9ba8076cc9ad1063428abde01e2639f07afd55bff37c58d01dc4300563a8d01bc253689024b5d124639723f67bc3a69d867cb8c89a4f3f373dac4d054931e9c452f5fcbba384d78a219d515fa8a0803f3095531b5e865e0ad8430296e6e1b3ee60a193376624db9561395abbe5f9
M1: b'Hello, RSA!'


Creating a signature is the inverse operation to encryption.
$$Sign(M)\equiv Dec(M),\space Check(S) \equiv Enc(S)$$
This way anyone with a public key can check the validity of the signature, but only the entity holding the private key can produce signatures.
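
The relation between signing and checking can be sketched on a toy key. The helper names `sign` and `check`, as well as the parameters, are assumptions made for illustration; this is textbook RSA without any padding or hashing.

```python
# Toy key (illustrative assumption, not secure)
p, q, e = 61, 53, 17
N = p * q
d = pow(e, -1, (p - 1) * (q - 1))   # needs Python 3.8+

def sign(m):
    """Sign(M): apply the private exponent, exactly like decryption."""
    return pow(m, d, N)

def check(m, s):
    """Check(S): apply the public exponent and compare with the message."""
    return pow(s, e, N) == m % N

M = 1234
S = sign(M)
assert check(M, S)           # the genuine signature verifies
assert not check(M + 1, S)   # the same signature does not fit another message
```
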
Congratulations. You know how to encrypt and create signatures with RSA. Now let's explore one of RSA's interesting properties.
## RSA Blinding
RSA is homomorphic under multiplication.
A homomorphism is a structure-preserving map between two algebraic structures of the same type. (If you didn't get anything from that, that's ok. I also didn't the first time I heard that.) What this means is that if you have two elements $(x,y)$ and you put them through a homomorphic function, they will relate to each other in the new group/ring/field, etc. the same way as they did in the original one:
$$\varphi(x\cdot y)=\varphi(x)\times\varphi(y)$$
This translates to $$Enc(M_1 \cdot M_2)=Enc(M_1)\times Enc(M_2)$$
The same is true for decryptions:
$$Dec(C_1 \times C_2)=Dec(C_1) \cdot Dec(C_2)$$
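
To see why this holds for textbook RSA, one can expand the definition of encryption (a short derivation in the notation used above, with all products taken modulo $N$):
$$Enc(M_1)\times Enc(M_2)=M_1^{e}\cdot M_2^{e}\space mod\space N=(M_1\cdot M_2)^{e}\space mod\space N=Enc(M_1\cdot M_2)$$
The same argument with $d$ in place of $e$ gives the identity for decryption.
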
Let's try this in Python.

In [6]:
class BasicRSA:
    def __init__(self, e,p,q):
        self.e=e
        self.p=p
        self.q=q
        self.N=p*q
        self.d=inverse(e,(p-1)*(q-1))
    
    def encryptNumber(self, m):
        return pow(m,self.e,self.N)
    
    def decryptNumber(self, c):
        return pow(c,self.d,self.N)

brsa=BasicRSA(e,p,q) #we created these parameters earlier
m1=2
m2=3
m3=m1*m2
c1=brsa.encryptNumber(m1)
c2=brsa.encryptNumber(m2)
print ('c1:',c1)
print ('c2:',c2)
c3=(c1*c2)%brsa.N
print('c3:',c3)
m3_dec=brsa.decryptNumber(c3)
print ('m3: %d, m3_dec: %d'%(m3,m3_dec))
assert m3_dec==m3

c1: 11014358485183199846714767764732588544363560216734986997751759838980896247036812118141673755964065645338643639675248246068562288738065082417407730360184076214990588791975181892296042721563827805731740951124609706579901318653758456775330473710491485431328583844391540358718461324637241675512940099643039749020599646650211762103144205410613508156429764270484909050485216692928577495624223260765599304190305618964836579678856974569227509451924698386106294477277334433622387815439366008962921414541506876469089551205189960717533351392543115659951741153005390374003996592787607476215857261150299842299244927932785444540546
c2: 22269917331550587681062869725336176848700267356150185557634432741963047730503809242923347246719638703856668622377030750771723279059943151242898428084220650713087458838740534416590669893631187830837473370479394255480342902372400549804753544166839643565700375554705922652925467408548277577091673397336101368127019175007277668743496322415140615497487853756415578742443831527343

## Attacking the server
Now try to apply this knowledge to a vulnerable server. You can connect to it with ```nc cryptotraining.zone 1337``` or by using python sockets.

In [7]:
import socket
import re
class VulnServerClient:
    def __init__(self,show=True):
        """Initialization, connecting to server"""
        self.s=socket.socket(socket.AF_INET,socket.SOCK_STREAM)
        self.s.connect(('cryptotraining.zone',1337))
        if show:
            print (self.recv_until().decode())
    def recv_until(self,symb=b'\n>'):
        """Receive messages from server, by default till new prompt"""
        data=b''
        while True:
            chunk=self.s.recv(1)
            if not chunk:
                # Connection closed by the server: stop instead of looping forever
                break
            data+=chunk
            if data[-len(symb):]==symb:
                break
        return data
    def get_public_key(self,show=True):
        """Receive public key from the server"""
        self.s.sendall('public\n'.encode())
        response=self.recv_until().decode()
        if show:
            print (response)
        e=int(re.search(r'(?<=e: )\d+',response).group(0))
        N=int(re.search(r'(?<=N: )\d+',response).group(0))
        self.num_len=len(long_to_bytes(N))
        return (e,N)
    
    def signBytes(self,m,show=True):
        """Get a signature for chosen byte message from the server"""
        try:
            num_len=self.num_len
        except AttributeError:
            print ('You need to get the public key from the server first')
            return
        if len(m)>num_len:
            print ("The message is too long")
            return
        if len(m)<num_len:
            m=bytes((num_len-len(m))*[0x0])+m
        hex_m=m.hex().encode()
        self.s.sendall(b'sign '+hex_m+b'\n')
        response=self.recv_until().decode()
        if show:
            print (response)
        if response.find('flag')!=-1:
            print('You tried to submit \'flag\'')
            return None
        signature_hex=re.search('(?<=Signature: )[0-9a-f]+',response).group(0)
        signature_bytes=bytes.fromhex(signature_hex)
        return bytes_to_long(signature_bytes)
    
    
    def signNumber(self,m,show=True):
        """Get a signature for chosen number from the server"""
        try:
            num_len=self.num_len
        except AttributeError:
            print ('You need to get the public key from the server first')
            return
        return self.signBytes(long_to_bytes(m,num_len),show)
        
    def checkSignatureNumber(self,c,show=True):
        """Check if this number is a valid signature for 'flag'"""
        try:
            num_len=self.num_len
        except AttributeError:
            print ('You need to get the public key from the server first')
            return
        signature_bytes=long_to_bytes(c,num_len)
        self.checkSignatureBytes(signature_bytes,show)
    
    def checkSignatureBytes(self,c,show=True):
        """Check if these bytes are a valid signature for 'flag'"""
        try:
            num_len=self.num_len
        except AttributeError:
            print ('You need to get the public key from the server first')
            return
        if len(c)>num_len:
            print ("The message is too long")
            return
        
        hex_c=c.hex().encode()
        self.s.sendall(b'flag '+hex_c+b'\n')
        response=self.recv_until(b'\n').decode()
        
        if show:
            print (response)
        
        if response.find('Wrong')!=-1:
            print('Wrong signature')
            x=self.recv_until()
            if show:
                print (x)
            return
        flag=re.search(r'CRYPTOTRAINING\{.*\}',response).group(0)
        print ('FLAG: ',flag)
        
    def __del__(self):
        self.s.close()


In [8]:
vs=VulnServerClient()
(e,N)=vs.get_public_key()

Welcome to RSA blinding task
Available commands:
help - print this help
public - show public key
sign <hex(data)> - sign data
flag <hex(signature(b'flag'))> - print flag 
quit - quit
>
e: 65537
N: 20159717663186764200842482638329142432479376755681286432561400011207751568770239378735042390550988864636478212097889382541806378632813451522011734778394352464750695430236459156439656932108536936107092785759187120915559173321302027525229018106368725032056109022369913503577023942696069608771010384365856481001383579432844112231215767630328627015097422540087789462404508697086321213990868031273219614897901436844999442259387453021270642395531884848697650933478124254071912232445708062597679170291021925633789812405697682134528381868778865376836541179591638312152472136313757252384761293684336082840137773984575947459061
>


You can sign messages with signNumber and signBytes methods.

You can check signatures with checkSignatureNumber and checkSignatureBytes methods.

Your goal is to get a valid signature for the message 'flag'. Note that the server will not sign 'flag' directly.

Remember that RSA is homomorphic and solve the task.

Good luck!