In [None]:
!pip install ff3

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ff3
  Downloading ff3-1.0.1-py3-none-any.whl (19 kB)
Collecting pycryptodome (from ff3)
  Downloading pycryptodome-3.18.0-cp35-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m26.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pycryptodome, ff3
Successfully installed ff3-1.0.1 pycryptodome-3.18.0


In [None]:


import tensorflow as tf
import numpy as np
from sklearn import datasets as ds
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import tensorflow as tf
from tensorflow import keras
import numpy as np
import pandas as pd
import re
import csv
import random
import string
import multiprocessing
import itertools
import random
import csv
import statistics
import string
import random
from tensorflow.keras.regularizers import l2
from keras.utils import to_categorical

"""

SPDX-Copyright: Copyright (c) Schoening Consulting, LLC
SPDX-License-Identifier: Apache-2.0
Copyright 2021 Schoening Consulting, LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.

"""

# Package ff3 implements the FF3-1 format-preserving encryption algorithm/scheme

import logging
import math
from Crypto.Cipher import AES
import string

# The recommendation in Draft SP 800-38G was strengthened to a requirement in Draft SP 800-38G Revision 1:
# the minimum domain size for FF1 and FF3-1 is one million.

# Fixed parameters of the FF3 / FF3-1 Feistel construction used throughout this file.
NUM_ROUNDS = 8                      # number of Feistel rounds
BLOCK_SIZE = 16                     # AES block size in bytes (aes.BlockSize)
TWEAK_LEN, TWEAK_LEN_NEW = 8, 7     # tweak length in bytes: original FF3 (64-bit) / FF3-1 (56-bit)
HALF_TWEAK_LEN = TWEAK_LEN >> 1     # bytes in each tweak half (Tl / Tr)


def reverse_string(txt):
    """Return ``txt`` with its elements in reverse order.

    Works on any sequence that supports extended slicing (str, bytes, bytearray)
    and returns an object of the same type. Defined as a named helper so the
    spec's REV() operation reads clearly at call sites.
    """
    backwards = slice(None, None, -1)
    return txt[backwards]


"""
FF3 encodes a string within a range of minLen..maxLen. The spec uses an alternating Feistel
with the following parameters:
    128-, 192-, or 256-bit AES key
    AES round function (one AES-ECB block encryption per round)
    64-bit (FF3) or 56-bit (FF3-1) tweak
    eight (8) rounds
    Modulo addition

An encoded string representation of x is in the given integer base, which must be at least 2. The
result uses the lower-case letters 'a' to 'z' for digit values 10 to 35 and upper-case letters 'A' to 'Z' for
digit values 36 to 61.

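Instead of specifying the base, an alphabet may be specified as a string of unique characters.
In this copy the built-in alphabet is string.printable (100 characters: the 62 characters above,
followed by punctuation and whitespace), so an explicit alphabet is mandatory only for bases larger than 100.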

FF3Cipher initializes a new FF3 Cipher object for encryption or decryption with key, tweak and radix parameters. The
default radix is len(string.printable) == 100; pass radix=10 to encrypt decimal numbers.

AES in ECB mode computes the round value. Each round encrypts a single 128-bit (16-byte) block P built from the
half-tweak, the round number and the zero-padded numeral value of one half. The AES output is then combined with
the other half by modular addition (mod radix^m). AES is only ever used in the encrypt direction: Feistel
decryption recomputes the same AES output and undoes each round with modular subtraction.
"""


class FF3Cipher:
    """Class FF3Cipher implements the FF3 / FF3-1 format-preserving encryption algorithm.

    If a radix between 2 and ``BASE62_LEN`` is specified, the first ``radix``
    characters of ``BASE62`` are used as the alphabet. NOTE: despite its name,
    ``BASE62`` is ``string.printable`` here (100 characters: digits, lower-case and
    upper-case latin letters, then punctuation and whitespace), so the default radix
    is 100. For radices above ``BASE62_LEN`` use ``withCustomAlphabet``.
    """
    DOMAIN_MIN = 1_000_000  # 1M required in FF3-1
    BASE62 = string.printable  # 100 chars; the first 62 are digits + lower + upper latin
    BASE62_LEN = len(BASE62)
    RADIX_MAX = 256  # Support 8-bit alphabets for now, requires test cases for larger values

    def __init__(self, key, tweak, radix=BASE62_LEN, ):
        """Create a cipher.

        Args:
            key: AES key as a hex string encoding 16, 24 or 32 bytes.
            tweak: default tweak as a hex string (7 bytes for FF3-1, 8 bytes for FF3).
                It is validated only when encrypt/decrypt is called.
            radix: size of the alphabet, 2..RADIX_MAX.

        Raises:
            ValueError: malformed or wrong-length key, radix out of range, or a radix
                whose [minLen..maxLen] message-length range is empty.
        """
        keybytes = bytes.fromhex(key)
        klen = len(keybytes)

        # Check if the key is 128, 192, or 256 bits = 16, 24, or 32 bytes
        if klen not in (16, 24, 32):
            raise ValueError(f'key length is {klen} but must be 128, 192, or 256 bits')

        # Validate radix BEFORE using it in any length computation; radix 0 or 1 would
        # otherwise fail with an unrelated ZeroDivisionError / math domain error.
        if (radix < 2) or (radix > FF3Cipher.RADIX_MAX):
            raise ValueError(f"radix must be between 2 and {FF3Cipher.RADIX_MAX}, inclusive")

        self.tweak = tweak
        self.radix = radix
        if radix <= FF3Cipher.BASE62_LEN:
            self.alphabet = FF3Cipher.BASE62[0:radix]
        else:
            # No built-in alphabet this large; withCustomAlphabet must supply one.
            self.alphabet = None

        # Supported message lengths [minLen..maxLen] per the revised (FF3-1) spec.
        self.minLen, self.maxLen = FF3Cipher._length_bounds(radix)

        # Make sure 2 <= minLength <= maxLength
        if (self.minLen < 2) or (self.maxLen < self.minLen):
            raise ValueError("minLen or maxLen invalid, adjust your radix")

        # AES block cipher in ECB mode with the block size derived based on the length of the key.
        # Always use the reversed key since Encrypt and Decrypt call ciph expecting that.
        self.aesCipher = AES.new(reverse_string(keybytes), AES.MODE_ECB)

    @staticmethod
    def _length_bounds(radix):
        """Return ``(minLen, maxLen)`` for ``radix`` (which must be >= 2).

        minLen is the smallest k >= 1 with radix**k >= DOMAIN_MIN, and maxLen is
        2 * (the largest k with radix**k <= 2**96). Exact integer arithmetic avoids
        relying on floating-point log ratios, which are inexact at exact powers
        such as radix 10 or 100.
        """
        min_len = 1
        while radix ** min_len < FF3Cipher.DOMAIN_MIN:
            min_len += 1
        half_max = 0
        while radix ** (half_max + 1) <= 2 ** 96:
            half_max += 1
        return min_len, 2 * half_max

    # factory method to create a FF3Cipher object with a custom alphabet
    @staticmethod
    def withCustomAlphabet(key, tweak, alphabet):
        """Create a cipher whose radix is len(alphabet) and whose digits are ``alphabet``."""
        c = FF3Cipher(key, tweak, len(alphabet))
        c.alphabet = alphabet
        return c

    def encrypt(self, plaintext):
        """Encrypts the plaintext string and returns a ciphertext of the same length and format"""
        return self.encrypt_with_tweak(plaintext, self.tweak)

    def _prepare(self, text, tweak):
        """Validate the message length and tweak; return ``(u, v, Tl, Tr)``.

        u = ceil(n/2) and v = n - u are the lengths of the left/right Feistel halves.
        Tl/Tr are the two 4-byte halves of the 64-bit tweak; a 56-bit FF3-1 tweak is
        first expanded with calculate_tweak64_ff3_1.

        Raises:
            ValueError: tweak is not valid hex, the message length is outside
                [minLen..maxLen], or the tweak is not 7 or 8 bytes long.
        """
        tweakBytes = bytes.fromhex(tweak)
        n = len(text)

        # Check if message length is within minLength and maxLength bounds
        if (n < self.minLen) or (n > self.maxLen):
            raise ValueError(f"message length {n} is not within min {self.minLen} and max {self.maxLen} bounds")

        # Make sure the given the length of tweak in bits is 56 or 64
        if len(tweakBytes) not in (TWEAK_LEN, TWEAK_LEN_NEW):
            raise ValueError(f"tweak length {len(tweakBytes)} invalid: tweak must be 56 or 64 bits")

        # Calculate split point
        u = math.ceil(n / 2)
        v = n - u

        if len(tweakBytes) == TWEAK_LEN_NEW:
            # FF3-1
            tweakBytes = calculate_tweak64_ff3_1(tweakBytes)

        Tl = tweakBytes[:HALF_TWEAK_LEN]
        Tr = tweakBytes[HALF_TWEAK_LEN:]
        logging.debug(f"Tweak: {tweak}, tweakBytes:{tweakBytes.hex()}")
        return u, v, Tl, Tr

    def _round_value(self, i, W, half):
        """Return round i's integer y = NUM(REV(AES(REV(P)))).

        P (16 bytes) is built by calculate_p from the half-tweak W, the round index
        and the numeral value of ``half`` (B when encrypting, A when decrypting).
        AES is always used in the encrypt direction.
        """
        P = calculate_p(i, self.alphabet, W, half)
        S = self.aesCipher.encrypt(bytes(reverse_string(P)))
        return int.from_bytes(reverse_string(S), byteorder='big')

    # Feistel structure
    #
    #         u length |  v length
    #         A block  |  B block
    #
    #             C <- modulo function
    #
    #         B' <- C  |  A' <- B
    #
    # Steps:
    #
    # Let u = [n/2]
    # Let v = n - u
    # Let A = X[1..u]
    # Let B = X[u+1,n]
    # Let T(L) = T[0..31] and T(R) = T[32..63]
    # for i <- 0..7 do
    #     If is even, let m = u and W = T(R) Else let m = v and W = T(L)
    #     Let P = REV([NUM<radix>(Rev(B))]^12 || W XOR REV(i^4)
    #     Let Y = CIPH(P)
    #     Let y = NUM<2>(REV(Y))
    #     Let c = (NUM<radix>(REV(A)) + y) mod radix^m
    #     Let C = REV(STR<radix>^m(c))
    #     Let A = B
    #     Let B = C
    # end for
    # Return A || B
    #
    # * Where REV(X) reverses the order of characters in the character string X
    #
    # See spec and examples:
    # https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-38Gr1-draft.pdf
    # https://csrc.nist.gov/CSRC/media/Projects/Cryptographic-Standards-and-Guidelines/documents/examples/FF3samples.pdf

    def encrypt_with_tweak(self, plaintext, tweak):
        """Encrypts the plaintext string using the given hex ``tweak`` (instead of the
        cipher's default tweak) and returns a ciphertext of the same length and format.

        Raises:
            ValueError: invalid tweak or length, or a plaintext character outside the alphabet.
        """
        u, v, Tl, Tr = self._prepare(plaintext, tweak)

        # Pre-calculate the modulus since it's only one of 2 values,
        # depending on whether i is even or odd
        modU = self.radix ** u
        modV = self.radix ** v
        logging.debug(f"modU: {modU} modV: {modV}")

        A = plaintext[:u]
        B = plaintext[u:]

        for i in range(NUM_ROUNDS):
            # Even rounds update the u-length half with Tr, odd rounds the v-length half with Tl.
            if i % 2 == 0:
                m, W, mod = u, Tr, modU
            else:
                m, W, mod = v, Tl, modV

            y = self._round_value(i, W, B)
            # Halves are combined with modular addition (decrypt uses subtraction).
            c = (decode_int_r(A, self.alphabet) + y) % mod
            A, B = B, encode_int_r(c, self.alphabet, m)

        return A + B

    def decrypt(self, ciphertext):
        """
        Decrypts the ciphertext string and returns a plaintext of the same length and format.

        The process of decryption is essentially the same as the encryption process. The differences
        are (1) the addition function is replaced by a subtraction function that is its
        inverse, and (2) the order of the round indices (i) is reversed.
        """
        return self.decrypt_with_tweak(ciphertext, self.tweak)

    def decrypt_with_tweak(self, ciphertext, tweak):
        """Decrypts the ciphertext string using the given hex ``tweak`` and returns a
        plaintext of the same length and format.

        Raises:
            ValueError: invalid tweak or length, or a ciphertext character outside the alphabet.
        """
        u, v, Tl, Tr = self._prepare(ciphertext, tweak)

        # Pre-calculate the modulus since it's only one of 2 values,
        # depending on whether i is even or odd
        modU = self.radix ** u
        modV = self.radix ** v
        logging.debug(f"modU: {modU} modV: {modV}")

        A = ciphertext[:u]
        B = ciphertext[u:]

        # Undo the rounds in reverse order.
        for i in reversed(range(NUM_ROUNDS)):
            if i % 2 == 0:
                m, W, mod = u, Tr, modU
            else:
                m, W, mod = v, Tl, modV

            # After encryption round i, A holds the half that fed the round function.
            y = self._round_value(i, W, A)
            c = (decode_int_r(B, self.alphabet) - y) % mod
            A, B = encode_int_r(c, self.alphabet, m), A

        return A + B

def calculate_p(i, alphabet, W, B):
    """Build the 16-byte round input block P.

    Layout: bytes 0..3 are W[0..3], with the round index i XORed into byte 3 only
    (XOR with the zero high bytes of [i]^4 is a no-op). The remaining 12 bytes are
    decode_int_r(B, alphabet), written big-endian.

    Args:
        i: round index.
        alphabet: alphabet used to interpret B as a number.
        W: half-tweak; only W[0..3] are read.
        B: Feistel half whose numeral value fills the low 12 bytes.

    Returns:
        bytearray of length BLOCK_SIZE.
    """
    tweak_part = bytearray((W[0], W[1], W[2], W[3] ^ int(i)))
    numeral_part = decode_int_r(B, alphabet).to_bytes(12, "big")

    block = bytearray(BLOCK_SIZE)
    block[:4] = tweak_part
    block[BLOCK_SIZE - len(numeral_part):] = numeral_part
    return block

def calculate_tweak64_ff3_1(tweak56):
    """Expand a 56-bit FF3-1 tweak into the 64-bit form the Feistel rounds consume.

    Byte 3 of the input is split across the two 4-byte halves. Its high nibble stays
    in byte 3 (low nibble cleared), and its low nibble moves to the high nibble of
    byte 7. Bytes 0..2 and 4..6 are copied unchanged.

    Returns:
        bytearray of length 8.
    """
    middle = tweak56[3]
    return bytearray((
        tweak56[0], tweak56[1], tweak56[2], middle & 0xF0,
        tweak56[4], tweak56[5], tweak56[6], (middle & 0x0F) << 4,
    ))

def encode_int_r(n, alphabet, length=0):
    """
    Return the digits of ``n`` in base ``len(alphabet)``, least-significant digit first.

    The digits are left in this reversed order because the calling cryptographic
    code expects it. If the result is shorter than ``length``, it is right-padded
    with ``alphabet[0]`` (zero digits). Longer results are never truncated.

    Raises:
        ValueError: if len(alphabet) exceeds FF3Cipher.RADIX_MAX.

    examples:
       encode_int_r(10, "0123456789abcdef")
        'a'
       encode_int_r(10, "0123456789", 4)
        '0100'
    """
    base = len(alphabet)
    if base > FF3Cipher.RADIX_MAX:
        raise ValueError(f"Base {base} is outside range of supported radix 2..{FF3Cipher.RADIX_MAX}")

    digits = []
    remaining = n
    while remaining >= base:
        remaining, digit = divmod(remaining, base)
        digits.append(alphabet[digit])
    digits.append(alphabet[remaining])
    encoded = "".join(digits)

    return encoded.ljust(length, alphabet[0]) if len(encoded) < length else encoded


def decode_int_r(astring, alphabet):
    """Decode a reversed (least-significant digit first) base-len(alphabet) string.

    The first character of ``astring`` is the least-significant digit, so this is
    the inverse of encode_int_r. The value is accumulated with Horner's rule,
    starting from the most-significant (last) character.

    Arguments:
    - `astring`: The encoded string
    - `alphabet`: The alphabet to use for decoding

    Raises:
        ValueError: if a character of ``astring`` is not in ``alphabet``.
    """
    base = len(alphabet)
    value = 0
    for symbol in reversed(astring):
        try:
            digit = alphabet.index(symbol)
        except ValueError:
            raise ValueError(f'char {symbol} not found in alphabet {alphabet}')
        value = value * base + digit
    return value

import ast
import string
import random

def _xor_difference(p0, delta, n_bytes=4):
    """XOR ``delta`` into the integer literal ``p0`` and return the low ``n_bytes`` as chars.

    Args:
        p0: Python integer literal string, e.g. '0x31323334' (the hex of a 4-char
            plaintext). Parsed with ast.literal_eval, so decimal literals are also accepted.
        delta: input difference to XOR into the value.
        n_bytes: number of low-order bytes to convert back to characters.

    Returns:
        A string of ``n_bytes`` characters, most-significant byte first
        (chr() of each byte value).
    """
    x = ast.literal_eval(p0) ^ delta
    return "".join(chr((x >> (8 * k)) & 0xFF) for k in reversed(range(n_bytes)))


def generate_diff1(p0):
    """Return the 4-char plaintext whose value is ``p0 ^ 0x08`` (input difference 0x08)."""
    return _xor_difference(p0, 0x08)


def generate_diff2(p0):
    """Return the 4-char plaintext whose value is ``p0 ^ 0x01`` (input difference 0x01)."""
    return _xor_difference(p0, 0x01)


def generate_diff3(p0):
    """Return the 4-char plaintext whose value is ``p0 ^ 0x02`` (input difference 0x02)."""
    return _xor_difference(p0, 0x02)





# FF3-1 cipher used to generate the dataset: a 128-bit key and a 56-bit tweak (hex
# strings). These are fixed experiment values, not secrets; never reuse them for real data.
# The default radix is len(string.printable) == 100, so ciphertext characters may be any
# printable ASCII character (including whitespace).
key = "2DE79D232DF5585D68CE47882AE256D6"
tweak = "CBD09280979564"
c = FF3Cipher(key, tweak)

'\n_LENGTH = 4\nstring_pool = string.ascii_lowercase\nplaintext = ""\nfor i in range(_LENGTH):\n    plaintext = plaintext + random.choice(string_pool)\n\nplaintext0 = plaintext.encode(\'utf-8\')\ntemp = plaintext0.hex()\ntemp = \'0x\' + temp\n\n\nciphertext = c.encrypt(plaintext)\ndecrypted = c.decrypt(ciphertext)\n\nc0= ciphertext.encode(\'utf-8\')\np0= decrypted.encode(\'utf-8\')\n\nbin_c1 = c0.hex()\nprint(bin (int(bin_c1, base=16)))\n\n\n# print(f"{plaintext} -> {ciphertext} -> {decrypted}")\n\n\np1 = generate_diff(temp)\nplaintext1 = p1.encode(\'utf-8\')\n\nciphertext = c.encrypt(p1)\ndecrypted = c.decrypt(ciphertext)\n\nc1= ciphertext.encode(\'utf-8\')\np1= decrypted.encode(\'utf-8\')\n\n\nbin_c1 = c1.hex()\nprint(bin (int(bin_c1, base=16)))\n\n\n\n# print(f"{p1} -> {ciphertext} -> {decrypted}")\n'

In [None]:
import numpy as np
from os import urandom



# Rows of the generated dataset. The generation loop appends
# [row_index, 32 bits of c0, 32 bits of the partner ciphertext, label] to each row.
X = []

# Total number of rows; the generation loop produces n // 3 plaintexts x 3 rows each.
n = 1200000


def Append(k, text):
    """Append every element of ``text`` (e.g. each char of a '0'/'1' bit string) to row X[k]."""
    # extend() is linear; the old per-index `list(text)[i]` rebuilt the list on every step.
    X[k].extend(text)


def Append2(k, text):
    """Append ``text`` as a single element (used for the label) to row X[k]."""
    X[k].append(text)

#baseline training data generator


l = 0  # index of the next row to append to X

# Per-sample integer encodings of 4-byte plaintexts (hex_p*) and ciphertexts (hex_c*).
# hex_p1, hex_c0c1 and hex_c0c2 are not filled anywhere in the visible notebook.
hex_p0 = np.zeros(shape=(n,), dtype=int)
hex_p1 = np.zeros(shape=(n,), dtype=int)
hex_p2 = np.zeros(shape=(n,), dtype=int)

hex_c0 = np.zeros(shape=(n,), dtype=int)
hex_c1 = np.zeros(shape=(n,), dtype=int)
hex_c2 = np.zeros(shape=(n,), dtype=int)
hex_c3 = np.zeros(shape=(n,), dtype=int)

hex_c0c1 = np.zeros(shape=(n,), dtype=int)
hex_c0c2 = np.zeros(shape=(n,), dtype=int)

# Class labels. For plaintext i, the generation loop reads Y[i], Y[n//3 + i] and
# Y[2n//3 + i] as the labels of its (c0, c1), (c0, c2) and (c0, c3) rows, i.e. the
# XOR-0x08, XOR-0x01 and XOR-0x02 difference classes -> labels 0, 1, 2.
Y = np.ones(shape=(n,))
Y[: n // 3] = 0
Y[(2 * n) // 3:] = 2





# For each of n // 3 random plaintexts p0, encrypt p0 and the three differenced
# plaintexts p0 ^ 0x08, p0 ^ 0x01 and p0 ^ 0x02 (generate_diff1/2/3). Then emit one
# row per pair: [row_index, bits(c0), bits(c_k), label].
_LENGTH = 4
string_pool = string.digits  # alternatives tried: string.ascii_lowercase

i = 0
while i != n // 3:
    Y[i] = 0  # label of the (c0, c1) row

    plaintext = "".join(random.choice(string_pool) for _ in range(_LENGTH))

    ciphertext = c.encrypt(plaintext)
    hex_c0[i] = int(ciphertext.encode('utf-8').hex(), base=16)

    plaintext0 = plaintext.encode('utf-8')
    p0_hex = '0x' + plaintext0.hex()

    # Fixed XOR differences of p0 are used instead of random second plaintexts.
    p1 = generate_diff1(p0_hex)  # p0 ^ 0x08
    p2 = generate_diff2(p0_hex)  # p0 ^ 0x01
    p3 = generate_diff3(p0_hex)  # p0 ^ 0x02

    try:
        c2 = c.encrypt(p2)
        c1 = c.encrypt(p1)
        c3 = c.encrypt(p3)
    except ValueError:
        # A differenced plaintext fell outside the cipher's alphabet: draw a new p0
        # for the same index i.
        continue

    hex_c1[i] = int(c1.encode('utf-8').hex(), base=16)
    hex_c2[i] = int(c2.encode('utf-8').hex(), base=16)
    hex_c3[i] = int(c3.encode('utf-8').hex(), base=16)
    hex_p0[i] = int(plaintext0.hex(), base=16)
    hex_p2[i] = int(p2.encode('utf-8').hex(), base=16)

    # Zero-padded 32-bit binary strings of the 4-byte ciphertexts.
    binc0 = bin(int(hex_c0[i]))[2:].zfill(32)
    binc1 = bin(int(hex_c1[i]))[2:].zfill(32)
    binc2 = bin(int(hex_c2[i]))[2:].zfill(32)
    binc3 = bin(int(hex_c3[i]))[2:].zfill(32)

    pair_rows = (
        (binc1, Y[i]),
        (binc2, Y[(n // 3) + i]),
        (binc3, Y[((2 * n) // 3) + i]),
    )
    for partner_bits, label in pair_rows:
        X.append([l])
        Append(l, binc0)
        Append(l, partner_bits)
        Append2(l, label)
        l = l + 1

    i = i + 1
  ######

# Write the first n rows of X to ./test.csv; `with` closes the file even if a write fails.
with open("./test.csv", "wt", newline="") as f:
    csvwriter = csv.writer(f)
    csvwriter.writerows(X[row] for row in range(0, n))

print("end")


0
end


In [None]:


# Load the generated rows. Column 0 is the row index, columns 1..2*n_bit are the
# ciphertext-pair bits, and the remaining column(s) hold the class label.
data = pd.read_csv("./test.csv", header=None).values

n_bit = 32              # bits per ciphertext
div = 2 * n_bit + 1     # index of the first label column

# Vectorised column slicing instead of a per-row Python loop over every CSV row.
datas = np.asarray(data[:, 1:div])
labels = np.asarray(data[:, div:])

# Fixed seed so the train/val/test split is reproducible across runs;
# set to None for a fresh random split on every run.
SPLIT_SEED = 42

x_train, x_val, y_train, y_val = train_test_split(
    datas, labels, test_size=0.4, random_state=SPLIT_SEED)
x_train, x_test, y_train, y_test = train_test_split(
    x_train, y_train, test_size=0.05, random_state=SPLIT_SEED)

# One-hot targets for the 3-way softmax. y_test stays as integer labels (presumably
# shape (N, 1)) for the accuracy cells further down.
y_train = to_categorical(y_train, num_classes=3)
y_val = to_categorical(y_val, num_classes=3)

print(len(x_train))
print(len(x_val))
print(len(x_test))
print(y_test)



In [None]:
# Hidden-layer width (units) for every Dense layer below.
u=128
# 64 inputs: the 32 bits of c0 followed by the 32 bits of the partner ciphertext.
inp = tf.keras.layers.Input(shape=(64, ))

x = tf.keras.layers.Dense(u, activation = 'relu',kernel_regularizer=l2(0.001))(inp)
shortcut = x
# Two residual blocks: Dense -> Dense on the running shortcut, then added back onto it.
# Layer creation order matters (auto-generated layer names / checkpoint topology).
for i in range(2):
    x = tf.keras.layers.Dense(u, activation = 'relu',kernel_regularizer=l2(0.001))(shortcut)
    x = tf.keras.layers.Dense(u, activation = 'relu',kernel_regularizer=l2(0.001))(x)
    shortcut = tf.keras.layers.Add()([shortcut, x])

# One softmax output per difference class (labels 0, 1, 2).
out = tf.keras.layers.Dense(3, activation = 'softmax')(shortcut)
model = tf.keras.Model(inputs=inp, outputs=out)

model.summary()

# Decay horizon: 35 * (training samples / 32), i.e. about 35 epochs of optimizer steps at batch_size 32.
s = 35 * len(x_train) // 32

# Positional args are presumably (initial_learning_rate, decay_steps, decay_rate) - see the Keras ExponentialDecay docs.
lr = keras.optimizers.schedules.ExponentialDecay(0.0001, s, 0.001)
opt = keras.optimizers.Adam(lr)

model.compile(optimizer=opt,
              loss='categorical_crossentropy', # matches the one-hot y_train and the 3-way softmax
              metrics=['accuracy'])

import os
# The {epoch:05d} placeholder suggests one weights file per epoch under ./ckpt/.
checkpoint_path = "./ckpt/epoch_{epoch:05d}.ckpt"
# NOTE(review): checkpoint_dir is not used anywhere in the visible notebook.
checkpoint_dir = os.path.dirname(checkpoint_path)
# Callback that saves only the model's weights; passed to model.fit below.
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True, verbose=2)



In [None]:
# Train with per-epoch validation and weight checkpointing (cp_callback).
model.fit(
    x_train,
    y_train,
    validation_data=(x_val, y_val),
    epochs=15,
    batch_size=32,
    callbacks=[cp_callback],
    verbose=2,
)

In [None]:
# Class probabilities for each test sample (one column per softmax output).
predicted_labels = model.predict(np.array(x_test))

# Predicted class = most probable output.
y_pred = np.argmax(predicted_labels, axis=1)

# Count matches without rebinding the builtin `sum` for the rest of the kernel.
# y_test is flattened so the comparison is element-wise whether it is (N, 1) or (N,).
n_correct = int(np.count_nonzero(y_pred == np.asarray(y_test).reshape(-1)))

acc = n_correct / len(x_test)
print(acc)




In [None]:
# Turn model outputs into class predictions for both model types. A single-output
# sigmoid model is thresholded at 0.5. A multi-column softmax model uses argmax,
# because thresholding it yields an (N, C) array that cannot be reshaped to (N,).
if predicted_labels.ndim == 2 and predicted_labels.shape[1] > 1:
    res = np.argmax(predicted_labels, axis=1)
else:
    res = np.array(predicted_labels > 0.5, dtype = int).reshape((len(x_test)))

y_test = y_test.reshape((len(y_test)))

# Number of test samples whose predicted class matches the label.
total = int(np.count_nonzero(res == y_test))

print(total/len(x_test))