In [None]:
import os
import sys

# Make modules in ../ciphers importable. The relative path is resolved against
# the current working directory at the time this cell runs.
sys.path.append(os.path.abspath('../ciphers'))

import speck3264 as speck3264

# Lookup table from a cipher name to its imported module.
cipher_dict = {
    "speck3264":speck3264
}

from DataGenerator import DataGenerator

import numpy as np
from os import urandom
from tensorflow.keras.regularizers import l2
from tensorflow.keras.backend import concatenate
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Dense, AveragePooling1D, Conv1D, MaxPooling1D, Input, Reshape, Permute, Add, Flatten, BatchNormalization, Activation, MultiHeadAttention
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler, CSVLogger
from pickle import dump
import tensorflow as tf
import tensorflow

def cyclic_lr(num_epochs, high_lr, low_lr):
    """Build a cyclic, linearly decaying learning-rate schedule.

    The returned callable maps an epoch index ``i`` to a learning rate that
    equals ``high_lr`` whenever ``i % num_epochs == 0`` and decreases linearly
    to ``low_lr`` at the last epoch of each ``num_epochs``-long cycle.

    Args:
        num_epochs: Length of one cycle, in epochs.
        high_lr: Learning rate at the first epoch of every cycle.
        low_lr: Learning rate at the last epoch of every cycle.

    Returns:
        A function ``schedule(i)`` returning the learning rate for epoch ``i``.
    """
    if num_epochs == 1:
        # A one-epoch cycle has no room to decay, and the linear formula below
        # would divide by ``num_epochs - 1 == 0``. Every epoch is the start of
        # a cycle, so the rate is always ``high_lr``.
        def constant_schedule(i):
            return float(high_lr)

        return constant_schedule

    def schedule(i):
        position = i % num_epochs
        return low_lr + ((num_epochs - 1) - position) / (num_epochs - 1) * (high_lr - low_lr)

    return schedule

def make_checkpoint(datei):
    """Create a ``ModelCheckpoint`` callback that writes to the path ``datei``.

    The callback is configured to monitor ``val_loss`` with
    ``save_best_only`` enabled.
    """
    checkpoint_options = {'monitor': 'val_loss', 'save_best_only': True}
    return ModelCheckpoint(datei, **checkpoint_options)

# Per-replica batch size; validate_search multiplies it by the number of
# replicas in sync to obtain the global batch size.
bs = 5000

In [None]:
def make_resnet(num_blocks=2, num_filters=32, num_outputs=1, ds=[64, 64], word_size=64, ks=3, depth=5, reg_param=0.0001, final_activation='sigmoid'):
    """Assemble a residual 1-D convolutional network as a Keras ``Model``.

    The model takes a flat vector of ``num_blocks * word_size * 2`` entries,
    reshapes it to ``(2 * num_blocks, word_size)`` and permutes the two axes,
    then applies a width-1 convolution, ``depth`` residual blocks of two
    convolutions each, and a stack of dense layers.

    Args:
        num_blocks: Together with ``word_size`` determines the input length.
        num_filters: Number of filters in every convolution.
        num_outputs: Number of units in the output layer.
        ds: Sizes of the hidden dense layers, in order (only iterated over).
        word_size: Length of each of the ``2 * num_blocks`` input rows.
        ks: Kernel size of the convolutions inside the residual blocks.
        depth: Number of residual blocks.
        reg_param: L2 regularisation factor applied to every kernel.
        final_activation: Activation of the output layer.

    Returns:
        The uncompiled Keras ``Model``.
    """
    def conv_bn_relu(tensor, kernel_size):
        # Conv1D -> BatchNormalization -> ReLU, with an L2-regularised kernel.
        tensor = Conv1D(num_filters, kernel_size=kernel_size, padding='same', kernel_regularizer=l2(reg_param))(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation('relu')(tensor)

    flat_input = Input(shape=(num_blocks * word_size * 2,))
    words = Reshape((2 * num_blocks, word_size))(flat_input)
    # Swap the two axes so the tensor becomes (word_size, 2 * num_blocks).
    features = conv_bn_relu(Permute((2, 1))(words), 1)

    # Residual tower: each block adds the output of two convolutions to its input.
    for _ in range(depth):
        branch = conv_bn_relu(conv_bn_relu(features, ks), ks)
        features = Add()([features, branch])

    # Dense head.
    hidden = Flatten()(features)
    for units in ds:
        hidden = Dense(units, kernel_regularizer=l2(reg_param))(hidden)
        hidden = BatchNormalization()(hidden)
        hidden = Activation('relu')(hidden)

    prediction = Dense(num_outputs, activation=final_activation, kernel_regularizer=l2(reg_param))(hidden)
    return Model(inputs=flat_input, outputs=prediction)

def validate_search(index, cipher, num_epochs, num_rounds=5, num_blocks=2, num_dataset=10**7, diff=(0x0040,0), num_filters=32, ds=[64,64], ks=3, depth=1, reg_param=10**-5, loss='mse', wdir="./freshly_trained_nets/", devices=("/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3")):
    """Train a residual network on generated cipher data and evaluate it.

    Builds the network with ``make_resnet`` under a ``MirroredStrategy``,
    trains it for ``num_epochs`` epochs with a cyclic learning rate, keeps the
    checkpoint with the best ``val_loss``, then reloads that checkpoint and
    evaluates it on a freshly constructed test generator.

    Args:
        index: Identifier of this run; only used in the output file names.
        cipher: Cipher module; must expose ``WORD_SIZE()`` and is passed on to
            ``DataGenerator``.
        num_epochs: Number of training epochs.
        num_rounds: Passed to ``DataGenerator`` and used in the file names.
        num_blocks: Passed to ``make_resnet``.
        num_dataset: Size argument for the training generator; validation and
            test generators receive ``num_dataset // 10``.
        diff: Passed to ``DataGenerator``.
        num_filters, ds, ks, depth, reg_param: Passed to ``make_resnet``.
        loss: Loss passed to ``compile``.
        wdir: Directory for the ``.h5`` checkpoint and the ``.log`` file. It is
            created if it does not exist.
        devices: Device strings for ``MirroredStrategy``. ``None`` lets the
            strategy pick the devices itself.

    Returns:
        ``(net, h)``: the model with the best checkpoint's weights loaded, and
        the ``History`` object returned by ``fit``.
    """
    # Build both output paths once. os.path.join also copes with a ``wdir``
    # that lacks a trailing separator.
    if wdir:
        os.makedirs(wdir, exist_ok=True)
    run_stem = os.path.join(wdir, 'best' + str(num_rounds) + 'depth' + str(depth) + '_' + str(index))
    weights_path = run_stem + '.h5'
    log_path = run_stem + '.log'

    strategy = tf.distribute.MirroredStrategy(
        devices=list(devices) if devices is not None else None,
        cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
    # ``bs`` is the per-replica batch size; scale it to the global batch size.
    batch_size = bs * strategy.num_replicas_in_sync

    with strategy.scope():
        net = make_resnet(num_blocks=num_blocks, num_filters=num_filters, ds=ds, word_size=cipher.WORD_SIZE(), ks=ks, depth=depth, reg_param=reg_param)
        net.compile(optimizer='adam', loss=loss, metrics=['acc'])

    training_generator = DataGenerator(cipher, num_dataset, batch_size, num_rounds, diff)
    validation_generator = DataGenerator(cipher, num_dataset//10, batch_size, num_rounds, diff)

    check = make_checkpoint(weights_path)
    lr = LearningRateScheduler(cyclic_lr(10,0.002, 0.0001))
    csv_logger = CSVLogger(log_path, append=True)

    h = net.fit(training_generator, epochs=num_epochs, validation_data=validation_generator, callbacks=[lr, check, csv_logger])

    # NOTE: the summary lines below are appended to the same file CSVLogger
    # writes to, so that file is not pure CSV after a run.
    best_val_acc = np.max(h.history['val_acc'])
    print("Best validation accuracy: ", best_val_acc)
    with open(log_path, 'a') as log_file:
        log_file.write(f"Best validation accuracy: {best_val_acc}\n")

    test_generator = DataGenerator(cipher, num_dataset//10, batch_size, num_rounds, diff)

    # Evaluate the best checkpoint rather than the weights of the last epoch.
    net.load_weights(weights_path)
    _test_loss, test_acc = net.evaluate(test_generator)
    print("Test accuracy: ", test_acc)
    with open(log_path, 'a') as log_file:
        log_file.write(f"Test accuracy: {test_acc}\n")

    return net, h

In [None]:
# Run 6: train and evaluate a depth-10 network on 5 rounds of the speck3264
# cipher module with input difference (0x0040, 0x0000).
cipher = cipher_dict['speck3264']
diff = (0x0040, 0x0000)
validate_search(
    index=6,
    cipher=cipher,
    num_epochs=50,
    num_rounds=5,
    num_blocks=2,
    num_dataset=10**7,
    diff=diff,
    num_filters=32,
    ds=[64, 64],
    ks=3,
    depth=10,
    reg_param=10**-5,
    loss='mse',
    wdir="./speck3264/00400000/",
)