# In [None]:
###################Encryption, Update and Keystream function definition Start#####################

def h(x):
    """Nonlinear filter h of the pre-output function.

    Returns x0&x1 ^ x2&x3 ^ x4&x5 ^ x6&x7 ^ x0&x4&x8 over the first nine
    entries of ``x``. Only ``&`` and ``^`` are used, so the entries may be
    0/1 ints or 1-bit z3 BitVecs.
    """
    x0, x1, x2, x3, x4, x5, x6, x7, x8 = (x[k] for k in range(9))
    quadratic_part = (x0 & x1) ^ (x2 & x3) ^ (x4 & x5) ^ (x6 & x7)
    return quadratic_part ^ (x0 & x4 & x8)

def keystream(L, N):
    """Pre-output bit y of the state (L = LFSR bits, N = NFSR bits).

    y = h(N12, L8, L13, L20, N95, L42, L60, L79, L94)
        ^ L93 ^ N2 ^ N15 ^ N36 ^ N45 ^ N64 ^ N73 ^ N89

    Works on lists of 0/1 ints and on lists of 1-bit z3 BitVecs alike.
    """
    y = h([N[12], L[8], L[13], L[20], N[95], L[42], L[60], L[79], L[94]])
    y ^= L[93]
    for tap in (2, 15, 36, 45, 64, 73, 89):   # linear NFSR taps, XORed in order
        y ^= N[tap]
    return y

def LFSR_bk_update(L):
    """Undo one forward LFSR clock.

    ``L`` is the LFSR state after a forward clock. The return value is the
    bit that sat in position 0 before that clock: the forward feedback bit
    (now at L[127]) XORed with the remaining forward taps, each shifted
    down by one position.
    """
    recovered = L[127]
    for tap in (6, 37, 69, 80, 95):
        recovered ^= L[tap]
    return recovered

def LFSR_fw_update(L):
    """Forward LFSR feedback: L0 ^ L7 ^ L38 ^ L70 ^ L81 ^ L96.

    Callers shift ``L`` left by one and append this bit at position 127.
    """
    feedback = L[0]
    for tap in (7, 38, 70, 81, 96):
        feedback ^= L[tap]
    return feedback

def NFSR_bk_update(N, l0):
    """Undo one forward NFSR clock.

    ``N`` is the NFSR state after a forward clock, and ``l0`` is the LFSR
    bit that was at position 0 before that clock (as recovered by
    LFSR_bk_update). Returns the NFSR bit that was at position 0 before the
    clock. Every forward tap index is shifted down by one, and the old
    feedback bit now sits at N[127].
    """
    linear_taps = (127, 25, 55, 90, 95)
    product_taps = (
        (2, 66), (10, 12), (16, 17), (26, 58), (39, 47), (60, 64), (67, 83),
        (87, 91, 92, 94), (21, 23, 24), (69, 77, 81),
    )
    acc = l0
    for tap in linear_taps:
        acc ^= N[tap]
    for group in product_taps:
        monomial = N[group[0]]
        for tap in group[1:]:
            monomial &= N[tap]
        acc ^= monomial
    return acc

def NFSR_fw_update(N, l0):
    """Forward NFSR feedback bit.

    ``l0`` is the current LFSR bit at position 0. Callers shift ``N`` left
    by one and append the returned bit at position 127.
    """
    linear_taps = (0, 26, 56, 91, 96)
    product_taps = (
        (3, 67), (11, 13), (17, 18), (27, 59), (40, 48), (61, 65), (68, 84),
        (22, 24, 25), (70, 78, 82), (88, 92, 93, 95),
    )
    acc = l0
    for tap in linear_taps:
        acc ^= N[tap]
    for group in product_taps:
        monomial = N[group[0]]
        for tap in group[1:]:
            monomial &= N[tap]
        acc ^= monomial
    return acc

def Encryption(K=None, IV=None):
    """Run the key/IV setup plus a keystream phase and return the final state.

    The NFSR is loaded with the 128-bit key. The LFSR is loaded with the
    96-bit IV, followed by 31 ones and a final 0. The registers are then
    clocked:
      * 256 times with the pre-output bit XORed into both feedbacks,
      * 64 times with key bits K[0..63] XORed into the NFSR feedback
        ("authenticator generation phase"; no output feedback),
      * 100 times in plain keystream mode.
    Pre-output bits are only computed where they influence the state.

    NOTE(review): the round counts and key re-introduction are this
    script's model of the cipher; confirm them against the Grain-128-AEAD
    specification before using this to reproduce official test vectors.

    Args:
        K: optional key as a list of 128 bits (0/1). When omitted, it is
            drawn with ``random.randint``.
        IV: optional IV as a list of 96 bits (0/1). When omitted, it is
            drawn with ``random.randint``; random bits are drawn key-first,
            then IV.

    Returns:
        (N, L): NFSR and LFSR contents, 128 bits each.

    Raises:
        ValueError: if ``K`` or ``IV`` has the wrong length.
    """
    import random

    #Key/IV setup phase
    if K is None:
        K = [random.randint(0, 1) for _ in range(128)]
    if IV is None:
        IV = [random.randint(0, 1) for _ in range(96)]
    if len(K) != 128 or len(IV) != 96:
        raise ValueError("K must have 128 bits and IV must have 96 bits")

    N = list(K)
    L = list(IV) + [1] * 31 + [0]          # positions 96..126 are 1, position 127 is 0

    #Initialisation of pre-output generator: output fed back into both registers
    for _ in range(256):
        z = keystream(L, N)
        n = NFSR_fw_update(N, L[0])
        l = LFSR_fw_update(L)
        L = L[1:] + [l ^ z]
        N = N[1:] + [n ^ z]

    #Authenticator generation phase: key bits re-enter the NFSR feedback
    for i in range(64):
        n = NFSR_fw_update(N, L[0]) ^ K[i]
        l = LFSR_fw_update(L)
        L = L[1:] + [l]
        N = N[1:] + [n]

    #Keystream phase: the output bits are not needed here, only the state
    for _ in range(100):
        n = NFSR_fw_update(N, L[0])
        l = LFSR_fw_update(L)
        L = L[1:] + [l]
        N = N[1:] + [n]
    return N, L
        
####################Encryption, Update and Keystream function definition End#####################




#Grain-128-AEAD Encryption
#Specification
#256-bit internal state
#NFSR 128-bit  LFSR 128-bit
#NFSR N       LFSR L
#Key 128-bit
#IV  96-bit


import math
import random
import time


state_len = 256      #state_size: must equal the combined register size sum(r)
mc_len = 32          #microcontroller word length: HW is leaked per mc_len-bit block
N_round = 40         #Total number of rounds needed
r = [128,128]        #Size of the registers LFSR (L) and NFSR (N) respectively

# The HW loops index the concatenated state N + L (length sum(r)) up to
# state_len, so a mismatch would either raise IndexError or silently skip bits.
if state_len != sum(r):
    raise ValueError("state_len must equal the combined register size sum(r)")

Nblock = math.ceil(state_len/mc_len)      #No. of Blocks

# Number of bits needed to represent the max. HW of any block (= mc_len),
# i.e. floor(log2(mc_len)) + 1. int.bit_length() computes this exactly,
# avoiding the floating-point rounding of math.log near powers of two.
bitlen = mc_len.bit_length()

N_round_f = math.ceil(N_round/2)                #No. of forward rounds
N_round_b = math.floor(N_round/2)               #No. of backward rounds

print(f"HW/{mc_len} Model Grain-128-AEAD")
print("Number of rounds in forward ", N_round_f," and in backward ", N_round_b)
print("Total number of rounds",N_round)

###############Grain-128-AEAD cipher with a random key and IV to generate keystream and HW ####################################

# Target state: the (NFSR, LFSR) pair returned by Encryption() after its
# pseudorandom phase. N_org/L_org are kept untouched as ground truth.
N_org, L_org = Encryption()
N, L = N_org[:], L_org[:]

# Recorded output bits per round; only the rounds the loops below select are filled in.
Z_f = [0 for _ in range(N_round_f)]
Z_b = [0 for _ in range(N_round_b)]

# hamm_wt_x[i][b]: Hamming weight of block b of the state at round i.
hamm_wt_f = [[0] * Nblock for _ in range(N_round_f)]
hamm_wt_b = [[0] * Nblock for _ in range(N_round_b)]


#Information on keystream and HW/32 in backward direction
for i in range(N_round_b):
    # Undo one forward clock. The LFSR bit shifted out comes back first,
    # because the NFSR inverse takes it as its l0 input.
    lfsr_out = LFSR_bk_update(L)
    nfsr_out = NFSR_bk_update(N, lfsr_out)
    L = [lfsr_out] + L[:r[0] - 1]
    N = [nfsr_out] + N[:r[1] - 1]

    temp = N + L                                   # full state, NFSR bits first

    # NOTE(review): backward step i reaches the state i+1 clocks before the
    # target. So `i % 2 == 0` records outputs at odd offsets (t-1, t-3, ...),
    # while the forward loop and Encryption() use even offsets (t, t+2, ...).
    # If only every other output is observable keystream, this should
    # probably be `i % 2 == 1`, together with the backward keystream
    # constraint in the SMT section.
    if i % 2 == 0:
        Z_b[i] = keystream(L, N)

    # Per-block Hamming weights: Nblock-1 full mc_len-bit blocks, then a final
    # block with the remaining bits up to state_len.
    tail_start = (Nblock - 1) * mc_len
    hamm_wt_b[i] = [sum(temp[j] for j in range(blk * mc_len, (blk + 1) * mc_len))
                    for blk in range(Nblock - 1)]
    hamm_wt_b[i].append(sum(temp[j] for j in range(tail_start, state_len)))
    
    
            
L, N = L_org[:], N_org[:]
# Walk forward from the target state, recording one output bit every other
# round and the per-block Hamming weights of every visited state.
for i in range(N_round_f):
    if i % 2 == 0:
        Z_f[i] = keystream(L, N)
    temp = N + L

    # Nblock-1 full mc_len-bit blocks, then a final block up to state_len.
    tail_start = (Nblock - 1) * mc_len
    hamm_wt_f[i] = [sum(temp[j] for j in range(blk * mc_len, (blk + 1) * mc_len))
                    for blk in range(Nblock - 1)]
    hamm_wt_f[i].append(sum(temp[j] for j in range(tail_start, state_len)))

    # The last recorded state is not clocked any further.
    if i == N_round_f - 1:
        continue
    lfsr_in = LFSR_fw_update(L)
    nfsr_in = NFSR_fw_update(N, L[0])
    L = L[1:r[0]] + [lfsr_in]
    N = N[1:r[1]] + [nfsr_in]
        


###############SMT Modelling to recover state bits from keystream and HW/32#########################################
from z3 import *
m = Solver()

# One 1-bit symbolic variable per cell of the target state.
L_var_org = [BitVec('l%d'%i,1) for i in range(r[0])]          #For register L
N_var_org = [BitVec('n%d'%i,1) for i in range(r[1])]          #For register N

# Fresh "dummy" variables naming every bit shifted in while clocking:
# N_round_f-1 forward clocks and N_round_b backward clocks (numbered from 1).
L_f = [BitVec('L_f%d'%i,1) for i in range(1,N_round_f)]       #LFSR L, forward direction
N_f = [BitVec('N_f%d'%i,1) for i in range(1,N_round_f)]       #NFSR N, forward direction
L_b = [BitVec('L_b%d'%i,1) for i in range(1,N_round_b+1)]     #LFSR L, backward direction
N_b = [BitVec('N_b%d'%i,1) for i in range(1,N_round_b+1)]     #NFSR N, backward direction

L = L_var_org[:]
N = N_var_org[:]

# Consistency check of the equation system: when True, every state variable
# is fixed to its true value (L_org, N_org), so a correct model must be sat.
# Keep False for an actual recovery attempt. To guess only some bits, add
# the corresponding m.add(L[i] == L_org[i]) / m.add(N[i] == N_org[i]) lines.
guess_all_bits = False
if guess_all_bits:
    print("The guessed bits are ")
    print("For register L ",end=" = ")
    for i in range(r[0]):
        print(i,end=", ")
        m.add(L[i] == L_org[i])
    print("\n For register N ",end=" = ")
    for i in range(r[1]):
        print(i,end=", ")
        m.add(N[i] == N_org[i])

#Arrays to store the keystream equations
Z_f_equ = [0]*N_round_f
Z_b_equ = [0]*N_round_b

#Arrays to store the dummy variable equations for both registers L and N
L_f_equ = [0]*(N_round_f-1)
N_f_equ = [0]*(N_round_f-1)

L_b_equ = [0]*N_round_b
N_b_equ = [0]*N_round_b

#Arrays to store the HW/32 expressions (one per block per round)
hamm_wt_f_equ = [[0 for i in range(Nblock)] for i in range(N_round_f)]
hamm_wt_b_equ = [[0 for i in range(Nblock)] for i in range(N_round_b)]


#Equation generation steps for backward internal states
for i in range(N_round_b):
    # Symbolic inverse clock. Each recovered bit is named by a dummy variable
    # (L_b[i] / N_b[i]) instead of substituting the feedback expression, so
    # expressions stay small. The defining equalities are collected in
    # L_b_equ / N_b_equ and handed to the solver later.
    lfsr_expr = LFSR_bk_update(L)
    nfsr_expr = NFSR_bk_update(N, lfsr_expr)
    L_b_equ[i] = (L_b[i] == lfsr_expr)
    N_b_equ[i] = (N_b[i] == nfsr_expr)
    L = [L_b[i]] + L[:r[0] - 1]
    N = [N_b[i]] + N[:r[1] - 1]

    # Hamming-weight constraints: widen every bit to bitlen bits so a block
    # sum (at most mc_len) fits, then pin it to the recorded weight.
    temp = N + L
    tail_start = (Nblock - 1) * mc_len
    for blk in range(Nblock - 1):
        hamm_wt_b_equ[i][blk] = sum(ZeroExt(bitlen - 1, temp[j])
                                    for j in range(blk * mc_len, (blk + 1) * mc_len))
        m.add(hamm_wt_b_equ[i][blk] == hamm_wt_b[i][blk])
    hamm_wt_b_equ[i][-1] = sum(ZeroExt(bitlen - 1, temp[j])
                               for j in range(tail_start, state_len))
    m.add(hamm_wt_b_equ[i][-1] == hamm_wt_b[i][-1])

    # NOTE(review): this parity must stay in sync with the loop that records
    # Z_b. For backward steps, i % 2 == 0 corresponds to odd offsets from the
    # target state (t-1, t-3, ...), the opposite of the forward direction.
    if i % 2 == 0:
        Z_b_equ[i] = (keystream(L, N) == Z_b[i])
        m.add(Z_b_equ[i])

    
L, N = L_var_org[:], N_var_org[:]
#Equations generation steps for forward internal states
for i in range(N_round_f):
    # Output constraint on every other round, matching the recorded Z_f.
    if i % 2 == 0:
        Z_f_equ[i] = (keystream(L, N) == Z_f[i])
        m.add(Z_f_equ[i])

    # Hamming-weight constraints: Nblock-1 full blocks, then the tail block.
    temp = N + L
    tail_start = (Nblock - 1) * mc_len
    for blk in range(Nblock - 1):
        hamm_wt_f_equ[i][blk] = sum(ZeroExt(bitlen - 1, temp[j])
                                    for j in range(blk * mc_len, (blk + 1) * mc_len))
        m.add(hamm_wt_f_equ[i][blk] == hamm_wt_f[i][blk])
    hamm_wt_f_equ[i][-1] = sum(ZeroExt(bitlen - 1, temp[j])
                               for j in range(tail_start, state_len))
    m.add(hamm_wt_f_equ[i][-1] == hamm_wt_f[i][-1])

    # No clock after the last constrained round.
    if i == N_round_f - 1:
        continue
    lfsr_expr = LFSR_fw_update(L)
    nfsr_expr = NFSR_fw_update(N, L[0])
    # Name the new bits with dummy variables; the equalities go to the solver later.
    L_f_equ[i] = (L_f[i] == lfsr_expr)
    N_f_equ[i] = (N_f[i] == nfsr_expr)
    L = L[1:] + [L_f[i]]
    N = N[1:] + [N_f[i]]


# Keystream and HW constraints were added inside the loops above; the
# dummy-variable definitions for both directions are added here in one batch.
Equ = L_f_equ + N_f_equ + L_b_equ + N_b_equ


m.add(Equ)              #Equations fed to the solver m
start = time.perf_counter()
temp = m.check()        #Check the satisfiability of the given system of equations
end = time.perf_counter()

#If Satisfiable
if temp == sat:
    print(temp)
    soln = m.model()           #soln contains the output
    print("SMT Time = ", end - start)         #Time taken by solver to solve

    # Read the recovered state back as plain ints. With model_completion=True,
    # eval returns a value even for a variable the model leaves unassigned
    # (where soln[v] would give None). as_long() turns the z3 numeral into an
    # int, so the comparison below is ordinary integer equality.
    out_L = [soln.eval(v, model_completion=True).as_long() for v in L_var_org]
    out_N = [soln.eval(v, model_completion=True).as_long() for v in N_var_org]

    #Check whether the original state is same as the output state
    #L_org, N_org are the original registers whereas out_L and out_N are the register from output of SMT
    if L_org == out_L and N_org == out_N:
        print("matched")
    else:
        print("############################not matched##############################################")
        print(L_org)
        print(N_org)
        print(out_L)
        print(out_N)

#If the solution is Unsatisfiable, check code and equations
elif temp == unsat:
    print("###################################unsat###################################################")
    print(temp)
    print(L_org)
    print(N_org)

#Solver gave up (e.g. timeout or resource limit): report why rather than calling it unsat
else:
    print(temp, m.reason_unknown())
    print(L_org)
    print(N_org)