In [None]:
''' Recover the 4-byte subkeys U and V jointly, using the (partial) output of the second-round SubCells as the intermediate value.
Each step guesses two bytes (u, v) at a time: one byte from U and the matching byte from V. '''

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tnrange
from scipy.stats import linregress
import seaborn as sns
import time
import rich as r
import pandas as pd
from IPython.display import clear_output # type: ignore
from bokeh.plotting import figure, show
from bokeh.io import output_notebook
from bokeh.palettes import brewer
import itertools

# Load the captured power traces and the matching per-trace input blocks.
# Use forward slashes (valid on Windows and POSIX): in a normal string literal
# '.\traces\...' contains the escape sequences '\t' (TAB) and '\n' (newline),
# which silently corrupt the path.
trace_array = np.load('./traces/trace.npy')
textin_array = np.load('./traces/nonce_text.npy')

numtraces = 1000  # number of traces to use

# For styling the output only
fmt = "{:02X}<br>{:.5f}"              # (key byte, score) cells
fmt_uv = "{:02X} {:02X}<br>{:.5f}"    # (u, v, score) cells


def format_stat(stat):
    """Render one result cell as an HTML snippet for the pandas Styler.

    ``stat`` is either a ``(key, score)`` pair or a ``(u, v, score)`` triple.
    Key bytes are shown as two-digit uppercase hex, and the score is shown
    with five decimals on a second line.

    A triple must not go through the pair format. That would print ``v`` as
    the score and hide the real score.
    """
    if len(stat) == 3:
        return fmt_uv.format(stat[0], stat[1], stat[2])
    return str(fmt.format(stat[0], stat[1]))

def color_corr_key(row):
    """Return one CSS color rule per entry of ``row``.

    The first entry (the top-ranked guess) is green and every other entry is
    red. An empty ``row`` yields an empty list.
    """
    styles = ["color: red"] * len(row)
    if styles:
        styles[0] = "color: green"
    return styles

# Round constants
# GIFT-128 round constants: 40 entries, one per round.
# NOTE(review): assumed to be indexed by 0-based round number (first round ->
# GIFT_RC[0]), as in the GIFT-COFB reference code. Confirm against the target
# implementation.
GIFT_RC = [
    0x01, 0x03, 0x07, 0x0F, 0x1F, 0x3E, 0x3D, 0x3B, 0x37, 0x2F,
    0x1E, 0x3C, 0x39, 0x33, 0x27, 0x0E, 0x1D, 0x3A, 0x35, 0x2B,
    0x16, 0x2C, 0x18, 0x30, 0x21, 0x02, 0x05, 0x0B, 0x17, 0x2E,
    0x1C, 0x38, 0x31, 0x23, 0x06, 0x0D, 0x1B, 0x36, 0x2D, 0x1A
]

def rowperm(S, B0_pos, B1_pos, B2_pos, B3_pos):
    """Bit-permute a 32-bit word (used as the PermBits step by the caller).

    For nibble ``b`` in 0..7 and bit ``j`` in 0..3, bit ``4*b + j`` of ``S``
    is moved to bit ``b + 8*Bj_pos`` of the result. Each ``Bj_pos`` is
    therefore the destination byte for the j-th bit of every nibble.
    """
    destinations = (B0_pos, B1_pos, B2_pos, B3_pos)
    out = 0
    for nibble in range(8):
        for j, dest in enumerate(destinations):
            out |= ((S >> (4 * nibble + j)) & 0x1) << (nibble + 8 * dest)
    return out

def intermediate(P, u, v, index):
    """Predict one byte of the partial round-2 SubCells output.

    The 16-byte block ``P`` is loaded as four big-endian 32-bit words.
    Round 1 is applied in full: SubCells, PermBits, AddRoundKey and the
    round constant. The AddRoundKey step only adds the guessed byte lane.
    Then the first two steps of round-2 SubCells are evaluated on byte lane
    ``index`` only.

    Parameters
    ----------
    P : sequence of at least 16 byte values (0..255)
        The block fed to the cipher (one row of the input-text array).
    u, v : int
        Guesses (0..255) for byte ``index`` of the subkeys XORed into S[2]
        and S[1] respectively.
    index : int
        Byte lane inside the 32-bit words: 0 is the rightmost (least
        significant) byte, and 3 is the leftmost.

    Returns
    -------
    int
        The predicted byte (0..255) of S[0] after the second step of
        round-2 SubCells.

    Raises
    ------
    ValueError
        If ``index`` is not one of 0, 1, 2, 3.
    """
    if index not in (0, 1, 2, 3):
        raise ValueError(f"index must be 0..3, got {index!r}")
    shift = 8 * index

    # Work on Python ints. Under NumPy >= 2 (NEP 50), a uint8 scalar shifted
    # left by 8/16/24 stays uint8 and silently becomes 0.
    p = [int(b) for b in P[:16]]
    S = [
        (p[4 * w] << 24) | (p[4 * w + 1] << 16) | (p[4 * w + 2] << 8) | p[4 * w + 3]
        for w in range(4)
    ]

    # ===SubCells=== - Round 1 #
    S[1] ^= S[0] & S[2]
    S[0] ^= S[1] & S[3]
    S[2] ^= S[0] | S[1]
    S[3] ^= S[2]
    S[1] ^= S[3]
    S[3] ^= 0xffffffff
    S[2] ^= S[0] & S[1]
    S[0], S[3] = S[3], S[0]

    # ===PermBits=== - Round 1 #
    S[0] = rowperm(S[0], 0, 3, 2, 1)
    S[1] = rowperm(S[1], 1, 0, 3, 2)
    S[2] = rowperm(S[2], 2, 1, 0, 3)
    S[3] = rowperm(S[3], 3, 2, 1, 0)

    # ===AddRoundKey=== - Round 1, restricted to the guessed byte lane #
    S2 = ((S[2] >> shift) & 0xFF) ^ u
    S1 = ((S[1] >> shift) & 0xFF) ^ v

    # Add round constant - Round 1 #
    # The first round uses GIFT_RC[0] (the round loop in the reference code
    # indexes GIFT_RC by a 0-based round counter). The constant only touches
    # the low byte, so a wrong index corrupts predictions for lane 0 alone.
    S[3] ^= 0x80000000 ^ GIFT_RC[0]
    S3 = (S[3] >> shift) & 0xFF
    S0 = (S[0] >> shift) & 0xFF

    # ===SubCells=== - Round 2 (first two steps only) #
    S1 ^= S0 & S2
    S0 ^= S1 & S3

    return S0

def hamming_weight(iv):
    """Return the number of '1' digits in ``bin(iv)`` (popcount for iv >= 0)."""
    return sum(digit == "1" for digit in bin(iv))

# Difference-of-means (DoM) attack.
# For each byte lane of U and V, try every (u, v) pair and split the traces
# by whether the predicted intermediate has Hamming weight > 4. Each guess is
# scored by the largest absolute difference between the two partitions' mean
# traces.
# NOTE: intermediate() is evaluated 65536 * numtraces times per byte lane,
# so expect a long runtime.
from IPython.display import display  # implicit inside Jupyter; explicit for plain scripts

printable = []   # per lane: list of (u, v, dom), best first
DOM = []         # per lane: array of 65536 DoM scores, indexed by u*256 + v
key_guess = []   # per lane: best guess encoded as u*256 + v (decode with divmod(g, 256))

traces = np.asarray(trace_array[:numtraces])
texts = textin_array[:numtraces]

for index in range(4):  # taking one byte each u, v of the 4-byte subkeys U, V
    temp = []
    mean_diffs = np.zeros(256 * 256)
    for u, v in itertools.product(range(256), range(256)):  # all (u, v) pairs
        hws = np.array([hamming_weight(intermediate(p, u, v, index)) for p in texts])
        ones = hws > 4
        if ones.all() or not ones.any():
            # One partition is empty, so its mean is NaN. A NaN score would
            # then be chosen as the "best" guess. Score such a guess as 0.
            dom = 0.0
        else:
            dom = np.max(np.abs(traces[ones].mean(axis=0) - traces[~ones].mean(axis=0)))
        key = u * 256 + v
        mean_diffs[key] = dom
        temp.append((u, v, mean_diffs[key]))  # key guesses u, v and their DoM

    DOM.append(mean_diffs)
    temp.sort(key=lambda x: x[2], reverse=True)  # sort by DoM value, highest first
    printable.append(temp)
    df = pd.DataFrame(printable).transpose()

    key_guess.append(int(np.argmax(mean_diffs)))

    clear_output(wait=True)  # clear the previous output
    display(df.head().style.format(format_stat).apply(color_corr_key, axis=0))  # display the current status
