In [None]:
# !pip install --quiet neural-structured-learning

In [None]:
import gc
import json
import math
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_addons as tfa
import neural_structured_learning as nsl

!pip install cached-property
from cached_property import cached_property
from shutil import copyfile

!pip install fastparquet
import fastparquet

!pip install Levenshtein
import Levenshtein as lev
import random

In [None]:
SEED = 42


def seed_everything(seed=SEED):
    """Seed Python's `random`, NumPy's global RNG and TensorFlow's global RNG with `seed`."""
    for set_seed in (random.seed, np.random.seed, tf.random.set_seed):
        set_seed(seed)


seed_everything()

In [None]:
# Copy the helper module into the working directory; it must keep its .py
# suffix so the import below can find it.
copyfile(
    src="/kaggle/input/ctc-tpu/CTC_TPU.py",
    dst="/kaggle/working//CTC_TPU.py",
)

# Bring the CTC loss helper into scope from the copied module.
from CTC_TPU import classic_ctc_loss

In [None]:
# Use a TPU distribution strategy when one can be connected; otherwise fall back
# to TensorFlow's default strategy. `tpu` stays None whenever the fallback is used.
tpu = None
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect(tpu="local")
    strategy = tf.distribute.TPUStrategy(tpu)
    print("on TPU")
    print("REPLICAS: ", strategy.num_replicas_in_sync)
except Exception as err:  # narrower than a bare except: lets KeyboardInterrupt/SystemExit through
    # Report why the TPU path failed instead of silently hiding it.
    print(f"TPU not available ({type(err).__name__}: {err}); using default strategy")
    tpu = None
    strategy = tf.distribute.get_strategy()

In [None]:
# Character -> integer class-index vocabulary from the competition files.
with open("/kaggle/input/asl-fingerspelling/character_to_prediction_index.json", "r") as f:
    char_to_num = json.load(f)

# Extra padding symbol and its index; encoded phrases are right-padded with it.
pad_token, pad_token_idx = '^', 59
char_to_num[pad_token] = pad_token_idx

# Inverse mapping: class index -> character.
num_to_char = dict((index, char) for char, index in char_to_num.items())

df = pd.read_csv('/kaggle/input/asl-fingerspelling/train.csv')

# Face-mesh landmark ids kept for the lips.
LIP = [
    61, 185, 40, 39, 37, 0, 267, 269, 270, 409,
    291, 146, 91, 181, 84, 17, 314, 405, 321, 375,
    78, 191, 80, 81, 82, 13, 312, 311, 310, 415,
    95, 88, 178, 87, 14, 317, 402, 318, 324, 308,
]
# Pose landmark ids, grouped as left (LPOSE) and right (RPOSE).
LPOSE = [13, 15, 17, 19, 21]
RPOSE = [14, 16, 18, 20, 22]
POSE = LPOSE + RPOSE


def _axis_columns(axis):
    """Column names for one coordinate `axis` ('x', 'y' or 'z'): right hand, left hand, pose, lips."""
    return (
        [f"{axis}_right_hand_{i}" for i in range(21)]
        + [f"{axis}_left_hand_{i}" for i in range(21)]
        + [f"{axis}_pose_{i}" for i in POSE]
        + [f"{axis}_face_{i}" for i in LIP]
    )


X, Y, Z = (_axis_columns(axis) for axis in "xyz")

# Flattened column order: every x column, then every y column, then every z column.
SEL_COLS = X + Y + Z
FRAME_LEN = 128
MAX_PHRASE_LENGTH = 64


def _col_positions(axis, part, landmark_ids=None):
    """Positions in SEL_COLS of the `axis` columns whose name contains `part`.

    When `landmark_ids` is given, keep only columns whose trailing two-digit
    landmark id is in it (used for the pose columns).
    """
    positions = []
    for pos, col in enumerate(SEL_COLS):
        if not col.startswith(f"{axis}_") or part not in col:
            continue
        if landmark_ids is not None and int(col[-2:]) not in landmark_ids:
            continue
        positions.append(pos)
    return positions


LIP_IDX_X   = _col_positions("x", "face")
RHAND_IDX_X = _col_positions("x", "right")
LHAND_IDX_X = _col_positions("x", "left")
RPOSE_IDX_X = _col_positions("x", "pose", RPOSE)
LPOSE_IDX_X = _col_positions("x", "pose", LPOSE)

LIP_IDX_Y   = _col_positions("y", "face")
RHAND_IDX_Y = _col_positions("y", "right")
LHAND_IDX_Y = _col_positions("y", "left")
RPOSE_IDX_Y = _col_positions("y", "pose", RPOSE)
LPOSE_IDX_Y = _col_positions("y", "pose", LPOSE)

LIP_IDX_Z   = _col_positions("z", "face")
RHAND_IDX_Z = _col_positions("z", "right")
LHAND_IDX_Z = _col_positions("z", "left")
RPOSE_IDX_Z = _col_positions("z", "pose", RPOSE)
LPOSE_IDX_Z = _col_positions("z", "pose", LPOSE)

In [None]:
def load_relevant_data_subset(pq_path):
    """Read the parquet file at `pq_path`, keeping only the columns listed in SEL_COLS."""
    frame = pd.read_parquet(pq_path, columns=SEL_COLS)
    return frame

# Load one example sequence (the first one of the first landmark file) for inspection.
inpdir = "/kaggle/input/asl-fingerspelling/train_landmarks"
file_id = df.file_id.iloc[0]
pqfile = f"{inpdir}/{file_id}.parquet"
seq_refs = df.loc[df.file_id == file_id]
seqs = load_relevant_data_subset(pqfile)

seq_id = seq_refs.sequence_id.iloc[0]
frames = seqs.loc[seqs.index == seq_id]
phrase = str(df.loc[df.sequence_id == seq_id].phrase.iloc[0])

In [None]:
def interp1d_(x, target_len, method='random'):
    """Resize `x` along its first axis to `target_len` rows with `tf.image.resize`.

    Args:
        x: image-like tensor accepted by `tf.image.resize`. NOTE(review): the caller
            in this notebook (`personnal_resample`) passes (frames, features, 1) —
            confirm for other callers.
        target_len: desired size of the first axis; values below 1 are clamped to 1.
        method: a `tf.image.resize` method name, or 'random' to draw one of
            bilinear / bicubic / nearest on each call.

    Returns:
        The resized tensor; the second axis keeps its original size.
    """
    target_len = tf.maximum(1, target_len)
    # Only the first axis changes; the second axis keeps its current size.
    out_size = (target_len, tf.shape(x)[1])

    if method == 'random':
        # NOTE(review): these thresholds give bilinear 33%, bicubic 17%, nearest 50%;
        # confirm this split is intended (an even 1/3 split would use 0.33 / 0.66).
        draw = tf.random.uniform(())
        if draw < 0.33:
            x = tf.image.resize(x, out_size, 'bilinear')
        elif draw < 0.5:
            x = tf.image.resize(x, out_size, 'bicubic')
        else:
            x = tf.image.resize(x, out_size, 'nearest')
    else:
        x = tf.image.resize(x, out_size, method)

    return x



def personnal_resample(x, rate=(0.7,1.1)):
    """Randomly time-stretch `x` along its first axis.

    A factor is drawn uniformly from [rate[0], rate[1]); the new length is
    int(factor * len(x)), clamped to at least 1 by `interp1d_`, which also picks
    the interpolation method at random.
    """
    factor = tf.random.uniform((), rate[0], rate[1])
    n_frames = tf.cast(tf.shape(x)[0], tf.float32)
    target = tf.cast(factor * n_frames, tf.int32)

    # interp1d_ works on an image-like tensor: add a trailing channel axis, then drop it.
    resized = interp1d_(tf.expand_dims(x, -1), target)
    return tf.squeeze(resized, axis=-1)



def personnal_temporal_crop(x, max_percent_crop=0.075):
    """Drop a random number of frames from the start and from the end of `x`.

    Each end loses an integer number of frames drawn uniformly from
    [0, max(1, int(max_percent_crop * len(x)))), so short sequences are
    simply left uncropped instead of failing.

    Args:
        x: tensor whose first axis is time.
        max_percent_crop: upper bound (exclusive) on the fraction cropped per end.

    Returns:
        `x` with the cropped frames removed.
    """
    n_frames = tf.shape(x)[0]
    max_crop = tf.cast(max_percent_crop * tf.cast(n_frames, dtype=tf.float32), dtype=tf.int32)
    # Integer tf.random.uniform requires minval < maxval; with maxval 1 it returns 0.
    max_crop = tf.maximum(max_crop, 1)

    crop_from = tf.random.uniform((), 0, max_crop, dtype=tf.int32)
    crop_to = tf.random.uniform((), 0, max_crop, dtype=tf.int32)

    # Use an explicit end index: `x[crop_from:-crop_to]` is empty when crop_to == 0.
    return x[crop_from:n_frames - crop_to]


def personnal_temporal_mask_2(x, min_mask=0.4, max_mask = 0.5, mask_value=float('nan')):  # previously tried min 0.2 / max 0.3
    """Overwrite randomly chosen rows (frames) of a 2-D tensor `x` with `mask_value`.

    The number of row indices drawn is int(n_rows * p), with p uniform in
    [min_mask, max_mask). Indices are drawn with replacement, so fewer distinct
    rows may end up masked.

    Returns:
        A tensor with the same shape as `x`.
    """
    n_rows = tf.shape(x)[0]
    n_cols = tf.shape(x)[1]
    mask_percent = tf.random.uniform((), min_mask, max_mask)
    mask_size = tf.cast(tf.cast(n_rows, tf.float32) * mask_percent, tf.int32)

    # (mask_size, 1) row indices, the index layout tensor_scatter_nd_update expects.
    indices = tf.random.uniform(shape=[mask_size], maxval=n_rows, dtype=tf.int32)[:, tf.newaxis]
    # Fill whole rows using the real column count instead of a hard-coded width.
    updates = tf.fill([mask_size, n_cols], mask_value)
    return tf.tensor_scatter_nd_update(x, indices, updates)

def personnal_spatial_mask(x, size=(0.2,0.4), mask_value=float('nan')):
    """Replace every value of `x` lying strictly inside a random interval with `mask_value`.

    The interval starts at an offset drawn from [0, 1) and its width is drawn
    uniformly from the `size` range.
    """
    # NOTE(review): the interval lives in raw value space; a mean/std-based offset was
    # considered by the original author.
    offset = tf.random.uniform(())
    width = tf.random.uniform((), *size)
    inside = tf.logical_and(x > offset, x < offset + width)
    return tf.where(inside, mask_value, x)

def personnal_spatial_mask_2(x, min_mask=0.4, max_mask = 0.5, mask_value=float('NaN')):
    """Set a random subset of columns (last axis) of `x` to `mask_value`.

    The number of masked columns is int(n_columns * p), with p uniform in
    [min_mask, max_mask); the columns are picked without replacement.
    """
    n_cols = tf.shape(x)[-1]
    fraction = tf.random.uniform((), *(min_mask, max_mask))
    n_masked = tf.cast(tf.cast(n_cols, tf.float32) * fraction, tf.int32)

    chosen = tf.random.shuffle(tf.range(n_cols))[:n_masked]
    # (n_cols,) indicator of the chosen columns; broadcasts over the leading axes of x.
    column_selected = tf.reduce_sum(tf.one_hot(chosen, n_cols, dtype=tf.float32), axis=0) > 0
    return tf.where(column_selected, mask_value, x)


def flip_lr_2(data):
    """Negate the first 92 features along the last axis and leave the rest unchanged.

    In the flattened [x | y | z] layout used by the pre_process1_* functions, the
    first 92 features are the x coordinates.
    """
    return tf.concat([-data[..., :92], data[..., 92:]], axis=-1)


In [None]:
def augment_mix(x):
    """Apply a random combination of at least two augmentations to `x`.

    Resample, temporal mask, spatial (column) mask and x-flip are each applied
    independently with probability 0.7, in that order. If none was applied,
    resample + temporal mask are applied. If exactly one was applied, a fixed
    companion is added: resample -> spatial mask, temporal mask -> flip,
    spatial mask -> temporal mask, flip -> resample.

    The decisions are kept as boolean tensors rather than a Python list. Under
    `tf.function` (e.g. via `pre_process1_aug_mix`), a Python list appended to
    inside a tensor-dependent `if` is filled while tracing the branch regardless
    of the runtime outcome, which made the fallback branches unreachable.
    """
    do_resample = tf.random.uniform(()) < 0.7
    if do_resample:
        x = personnal_resample(x)
    do_temporal = tf.random.uniform(()) < 0.7
    if do_temporal:
        x = personnal_temporal_mask_2(x)
    do_spatial = tf.random.uniform(()) < 0.7
    if do_spatial:
        x = personnal_spatial_mask_2(x)
    do_flip = tf.random.uniform(()) < 0.7
    if do_flip:
        x = flip_lr_2(x)

    # Number of augmentations actually applied, as a runtime tensor.
    n_applied = tf.add_n(
        [tf.cast(flag, tf.int32) for flag in (do_resample, do_temporal, do_spatial, do_flip)]
    )

    if n_applied == 0:
        # Nothing applied: force two augmentations.
        x = personnal_resample(x)
        x = personnal_temporal_mask_2(x)
    elif n_applied == 1:
        # Exactly one applied: add its fixed companion so at least two are used.
        if do_resample:
            x = personnal_spatial_mask_2(x)
        elif do_temporal:
            x = flip_lr_2(x)
        elif do_spatial:
            x = personnal_temporal_mask_2(x)
        else:
            x = personnal_resample(x)

    return x

In [None]:
def augment(x):
    """Apply exactly one randomly chosen augmentation to `x`.

    Probabilities: resample 30%, temporal mask 35%, spatial (column) mask 17.5%,
    x-flip 17.5%.
    """
    if tf.random.uniform(()) < 0.3:
        x = personnal_resample(x)
    elif tf.random.uniform(()) < 0.5:
        x = personnal_temporal_mask_2(x)
    elif tf.random.uniform(()) < 0.5:
        x = personnal_spatial_mask_2(x)
    else:
        x = flip_lr_2(x)
    return x

In [None]:
@tf.function()
def tf_nan_mean(x, axis=0, keepdims=False):
    """Mean of `x` along `axis`, ignoring NaN entries.

    Slices that are entirely NaN produce 0/0, i.e. NaN.
    """
    is_nan = tf.math.is_nan(x)
    total = tf.reduce_sum(tf.where(is_nan, tf.zeros_like(x), x), axis=axis, keepdims=keepdims)
    count = tf.reduce_sum(tf.cast(tf.logical_not(is_nan), x.dtype), axis=axis, keepdims=keepdims)
    return total / count


@tf.function()
def tf_nan_std(x, center=None, axis=0, keepdims=False):
    """Population standard deviation of `x` along `axis`, ignoring NaN entries.

    `center` defaults to the NaN-ignoring mean along `axis` (with kept dims so it
    broadcasts against `x`).
    """
    mean = tf_nan_mean(x, axis=axis, keepdims=True) if center is None else center
    deviation = x - mean
    return tf.sqrt(tf_nan_mean(tf.square(deviation), axis=axis, keepdims=keepdims))

@tf.function()
def resize_pad(x):
    """Bring a 3-D tensor `x` to exactly FRAME_LEN entries along axis 0.

    Shorter inputs are padded at the end with -1000.0. Inputs with at least
    FRAME_LEN frames are resized with `tf.image.resize` (default method); the
    second axis keeps its size.
    """
    n_frames = tf.shape(x)[0]
    if n_frames >= FRAME_LEN:
        out = tf.image.resize(x, (FRAME_LEN, tf.shape(x)[1]))
    else:
        out = tf.pad(x, [[0, FRAME_LEN - n_frames], [0, 0], [0, 0]], constant_values=-1000.0)
    return out

@tf.function(jit_compile=True)
def pre_process0(x):
    """Split a flat per-frame landmark tensor into per-part (frames, landmarks, 3) tensors.

    `x` is indexed along axis 1 with the *_IDX_X/Y/Z positions, so it presumably
    holds the SEL_COLS columns in order (TODO confirm at call sites). Frames whose
    right- plus left-hand coordinates (NaN counted as 0) sum to exactly 0 are
    dropped from every part.

    Returns:
        (lip, rhand, lhand, rpose, lpose), each of shape (kept_frames, n_landmarks, 3).
    """
    def _xyz(idx_x, idx_y, idx_z):
        # Gather the three coordinate column sets and stack them on a new trailing axis.
        return tf.stack(
            [tf.gather(x, idx_x, axis=1), tf.gather(x, idx_y, axis=1), tf.gather(x, idx_z, axis=1)],
            axis=-1,
        )

    lip = _xyz(LIP_IDX_X, LIP_IDX_Y, LIP_IDX_Z)
    rhand = _xyz(RHAND_IDX_X, RHAND_IDX_Y, RHAND_IDX_Z)
    lhand = _xyz(LHAND_IDX_X, LHAND_IDX_Y, LHAND_IDX_Z)
    rpose = _xyz(RPOSE_IDX_X, RPOSE_IDX_Y, RPOSE_IDX_Z)
    lpose = _xyz(LPOSE_IDX_X, LPOSE_IDX_Y, LPOSE_IDX_Z)

    hands = tf.concat([rhand, lhand], axis=1)
    hands = tf.where(tf.math.is_nan(hands), 0.0, hands)
    keep_frame = tf.math.not_equal(tf.reduce_sum(hands, axis=[1, 2]), 0.0)

    return tuple(tf.boolean_mask(part, keep_frame) for part in (lip, rhand, lhand, rpose, lpose))

@tf.function()
def pre_process1(lip, rhand, lhand, rpose, lpose):
    """Normalize and pack landmark parts into one flat feature tensor (no augmentation).

    `pre_process_fn` maps this over the training/validation datasets, so it must be
    defined (it was commented out, which made those `.map` calls raise NameError).

    Each part is standardized per landmark and coordinate with NaN-aware mean/std
    over frames (axis 0). The parts are concatenated, padded/resized to FRAME_LEN
    frames by `resize_pad`, flattened to the [x | y | z] layout along the last axis,
    and any NaN is replaced by 0.
    """
    lip   = (lip - tf_nan_mean(lip, keepdims=True)) / (tf_nan_std(lip, keepdims=True))
    rhand = (rhand - tf_nan_mean(rhand, keepdims=True)) / (tf_nan_std(rhand, keepdims=True))
    lhand = (lhand - tf_nan_mean(lhand, keepdims=True)) / (tf_nan_std(lhand, keepdims=True))
    rpose = (rpose - tf_nan_mean(rpose, keepdims=True)) / (tf_nan_std(rpose, keepdims=True))
    lpose = (lpose - tf_nan_mean(lpose, keepdims=True)) / (tf_nan_std(lpose, keepdims=True))

    x = tf.concat([lip, rhand, lhand, rpose, lpose], axis=1)
    x = resize_pad(x)
    # (frames, landmarks, 3) -> (frames, 3 * landmarks) as [all x | all y | all z]
    x = tf.unstack(x, axis=-1)
    x = tf.concat(x, axis=-1)

    x = tf.where(tf.math.is_nan(x), 0.0, x)
    return x

# Variant of the pre-processing with additional data augmentation.
@tf.function()
def pre_process1_aug(lip, rhand, lhand, rpose, lpose):
    """Augment (with `augment`), normalize and pack landmark parts into a flat tensor.

    Augmentation runs before normalization, on the flattened [x | y | z] layout
    split at 92 features per coordinate. Each part is then standardized per
    landmark and coordinate with NaN-aware mean/std over frames, padded/resized
    to FRAME_LEN frames, flattened back to [x | y | z], and NaNs become 0.
    """
    parts = [lip, rhand, lhand, rpose, lpose]
    part_sizes = [part.shape[1] for part in parts]

    # (T, landmarks, 3) -> (T, 3 * landmarks) laid out as [all x | all y | all z]
    flat = tf.concat(tf.unstack(tf.concat(parts, axis=1), axis=-1), axis=-1)
    flat = augment(flat)
    # Back to (T', landmarks, 3); T' may differ from T after resampling.
    coords = tf.stack([flat[..., :92], flat[..., 92:2 * 92], flat[..., 2 * 92:]], axis=-1)

    normalized = [
        (chunk - tf_nan_mean(chunk, keepdims=True)) / (tf_nan_std(chunk, keepdims=True))
        for chunk in tf.split(coords, num_or_size_splits=part_sizes, axis=1)
    ]

    packed = resize_pad(tf.concat(normalized, axis=1))
    packed = tf.concat(tf.unstack(packed, axis=-1), axis=-1)
    return tf.where(tf.math.is_nan(packed), 0.0, packed)

@tf.function()
def pre_process1_aug_mix(lip, rhand, lhand, rpose, lpose):
    """Like `pre_process1_aug`, but augments with `augment_mix` (two or more augmentations).

    Augmentation runs before normalization, on the flattened [x | y | z] layout
    split at 92 features per coordinate. Each part is then standardized per
    landmark and coordinate with NaN-aware mean/std over frames, padded/resized
    to FRAME_LEN frames, flattened back to [x | y | z], and NaNs become 0.
    """
    parts = [lip, rhand, lhand, rpose, lpose]
    part_sizes = [part.shape[1] for part in parts]

    # (T, landmarks, 3) -> (T, 3 * landmarks) laid out as [all x | all y | all z]
    flat = tf.concat(tf.unstack(tf.concat(parts, axis=1), axis=-1), axis=-1)
    flat = augment_mix(flat)
    # Back to (T', landmarks, 3); T' may differ from T after resampling.
    coords = tf.stack([flat[..., :92], flat[..., 92:2 * 92], flat[..., 2 * 92:]], axis=-1)

    normalized = [
        (chunk - tf_nan_mean(chunk, keepdims=True)) / (tf_nan_std(chunk, keepdims=True))
        for chunk in tf.split(coords, num_or_size_splits=part_sizes, axis=1)
    ]

    packed = resize_pad(tf.concat(normalized, axis=1))
    packed = tf.concat(tf.unstack(packed, axis=-1), axis=-1)
    return tf.where(tf.math.is_nan(packed), 0.0, packed)

# Model input: FRAME_LEN frames x one feature per selected landmark column
# (3 coordinates x 92 landmarks), derived from the config so it cannot drift.
INPUT_SHAPE = [FRAME_LEN, len(SEL_COLS)]

In [None]:
# NEW
def decode_fn(record_bytes):
    """Parse one serialized TFRecord example into landmark tensors and the encoded phrase.

    Returns:
        (lip, rhand, lhand, rpose, lpose, phrase). Each landmark tensor is reshaped to
        (frames, n_landmarks, 3) with n_landmarks = 40, 21, 21, 5, 5 respectively;
        `phrase` is a 1-D int64 tensor.
    """
    landmark_counts = {"lip": 40, "rhand": 21, "lhand": 21, "rpose": 5, "lpose": 5}

    schema = {name: tf.io.VarLenFeature(tf.float32) for name in landmark_counts}
    schema["phrase"] = tf.io.VarLenFeature(tf.int64)
    parsed = tf.io.parse_single_example(record_bytes, schema)

    landmarks = [
        tf.reshape(tf.sparse.to_dense(parsed[name]), (-1, count, 3))
        for name, count in landmark_counts.items()
    ]
    phrase = tf.sparse.to_dense(parsed["phrase"])
    return (*landmarks, phrase)

def _pad_phrase(phrase):
    """Cut `phrase` to MAX_PHRASE_LENGTH tokens and right-pad it with `pad_token_idx`."""
    # Truncating first keeps tf.pad from receiving a negative pad amount (which raises)
    # when a phrase is longer than MAX_PHRASE_LENGTH.
    phrase = phrase[:MAX_PHRASE_LENGTH]
    return tf.pad(phrase, [[0, MAX_PHRASE_LENGTH - tf.shape(phrase)[0]]], constant_values=pad_token_idx)


def pre_process_fn(lip, rhand, lhand, rpose, lpose, phrase):
    """Dataset map fn: landmarks + phrase -> (pre_process1 features, padded phrase)."""
    return pre_process1(lip, rhand, lhand, rpose, lpose), _pad_phrase(phrase)


def pre_process_fn_aug(lip, rhand, lhand, rpose, lpose, phrase):
    """Dataset map fn: landmarks + phrase -> (pre_process1_aug features, padded phrase)."""
    return pre_process1_aug(lip, rhand, lhand, rpose, lpose), _pad_phrase(phrase)


def pre_process_fn_aug_mix(lip, rhand, lhand, rpose, lpose, phrase):
    """Dataset map fn: landmarks + phrase -> (pre_process1_aug_mix features, padded phrase)."""
    return pre_process1_aug_mix(lip, rhand, lhand, rpose, lpose), _pad_phrase(phrase)

# One TFRecord per landmark file; the first 5% of the files are held out for validation.
tffiles = [f"/kaggle/input/personnal-data-3/tfds/{fid}.tfrecord" for fid in df.file_id.unique()]
val_len = int(0.05 * len(tffiles))
print('val_len: ' + str(val_len))
train_batch_size = 1024 # was 32
val_batch_size = 1024 # was 32


def _records_to_examples(files, preprocess):
    """Read `files` as TFRecords, decode them with `decode_fn`, then map `preprocess` over them."""
    records = tf.data.TFRecordDataset(files).prefetch(tf.data.AUTOTUNE)
    decoded = records.map(decode_fn, num_parallel_calls=tf.data.AUTOTUNE)
    return decoded.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)


train_dataset_pre = _records_to_examples(tffiles[val_len:], pre_process_fn)
val_dataset_pre = _records_to_examples(tffiles[:val_len], pre_process_fn)

# Augmented training variants (currently disabled):
# aug_dataset_pre = _records_to_examples(tffiles[val_len:], pre_process_fn_aug)
# aug_mix_dataset_pre = _records_to_examples(tffiles[val_len:], pre_process_fn_aug_mix)


In [None]:
# NEW, with optional augmentation.
# Materialize the pre-processed examples in memory, then rebuild batched datasets from them.
val_items = list(val_dataset_pre)
val_items_X = [features for features, _ in val_items]
val_items_y = [tf.cast(labels, dtype=tf.int32) for _, labels in val_items]

# Optional augmented copies (disabled):
# aug_items = list(aug_dataset_pre)                 # classic augmentation
# aug_ratio = int(len(aug_items) * 0.5)
# aug_mix_items = list(aug_mix_dataset_pre)         # mixed augmentation
# aug_mix_ratio = int(len(aug_mix_items) * 0.5)

train_items = list(train_dataset_pre)
# train_items += aug_items[:aug_ratio]              # add part of the augmented data
# train_items += aug_mix_items[:aug_mix_ratio]      # add part of the mix-augmented data

print(f"nombre data total {len(train_items)}")
train_items_X = [features for features, _ in train_items]
train_items_y = [tf.cast(labels, dtype=tf.int32) for _, labels in train_items]

# Build the batched datasets from the in-memory (X, y) pairs.
val_dataset = (
    tf.data.Dataset.from_tensor_slices((val_items_X, val_items_y))
    .prefetch(tf.data.AUTOTUNE)
    .batch(val_batch_size, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_items_X, train_items_y))
    .prefetch(tf.data.AUTOTUNE)
    .shuffle(len(train_items))
    .repeat(2)
    .batch(train_batch_size, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

batch = next(iter(val_dataset))
batch[0].shape, batch[1].shape

In [None]:
import typing
seed_everything(seed=SEED)  # re-seed before building the model layers
def shape_list(x, out_type=tf.int32):
    """Return the shape of `x` as a list: static ints where known, and scalar
    tensors taken from `tf.shape(x)` for dimensions that are unknown statically."""
    dynamic = tf.shape(x, out_type=out_type)
    dims = []
    for axis, size in enumerate(x.shape.as_list()):
        dims.append(dynamic[axis] if size is None else size)
    return dims


def get_shape_invariants(tensor):
    """TensorShape of `tensor` in which every non-static dimension is None."""
    known = []
    for dim in shape_list(tensor):
        known.append(dim if isinstance(dim, int) else None)
    return tf.TensorShape(known)


def get_float_spec(tensor):
    """float32 TensorSpec with `tensor`'s static shape (unknown dimensions as None)."""
    return tf.TensorSpec(get_shape_invariants(tensor), dtype=tf.float32)

class GLU(tf.keras.layers.Layer):
    """Gated linear unit: split the input into two halves along `axis` and return a * sigmoid(b)."""

    def __init__(self, axis=-1, name="glu_activation", **kwargs):
        super().__init__(name=name, **kwargs)
        self.axis = axis

    def call(self, inputs, **kwargs):
        values, gate = tf.split(inputs, 2, axis=self.axis)
        return values * tf.sigmoid(gate)

    def get_config(self):
        config = super().get_config()
        config["axis"] = self.axis
        return config


class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head scaled dot-product attention with per-head einsum projections.

    `call` takes `inputs = [query, key, value]` and returns an output of shape
    (..., N_query, output_size). When `return_attn_coef` is True it also returns
    the attention weights of shape (..., num_heads, N_query, N_key).

    NOTE(review): the default `kernel_initializer` is a single seeded GlorotUniform
    instance. It is created once when this `def` is evaluated and shared by every
    kernel of every layer instance. Depending on the Keras version, a seeded
    initializer may return the same values on every call, so same-shaped kernels
    (query/key/value when their feature sizes match) would start identical. Confirm
    against the installed Keras version.
    """
    def __init__(
        self,
        num_heads,
        head_size,
        output_size: int = None,
        dropout: float = 0.0,
        use_projection_bias: bool = True,
        return_attn_coef: bool = False,
        kernel_initializer: typing.Union[str, typing.Callable] = tf.keras.initializers.glorot_uniform(seed=SEED),
        kernel_regularizer: typing.Union[str, typing.Callable] = None,
        kernel_constraint: typing.Union[str, typing.Callable] = None,
        bias_initializer: typing.Union[str, typing.Callable] = "zeros",
        bias_regularizer: typing.Union[str, typing.Callable] = None,
        bias_constraint: typing.Union[str, typing.Callable] = None,
        **kwargs,
    ):
        """Store hyper-parameters; the weights are created in `build`.

        Args:
            num_heads: number of attention heads.
            head_size: per-head projection size for query, key and value.
            output_size: size of the output projection; defaults to the value feature size.
            dropout: dropout rate applied to the attention weights.
            use_projection_bias: add a bias after the output projection.
            return_attn_coef: also return the (pre-dropout) attention weights from `call`.
            kernel_* / bias_*: Keras initializer / regularizer / constraint identifiers.

        Raises:
            ValueError: if `output_size` is given and smaller than 1.
        """
        super(MultiHeadAttention, self).__init__(**kwargs)

        if output_size is not None and output_size < 1:
            raise ValueError("output_size must be a positive number")

        self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
        self.kernel_constraint = tf.keras.constraints.get(kernel_constraint)
        self.bias_initializer = tf.keras.initializers.get(bias_initializer)
        self.bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
        self.bias_constraint = tf.keras.constraints.get(bias_constraint)

        self.head_size = head_size
        self.num_heads = num_heads
        self.output_size = output_size
        self.use_projection_bias = use_projection_bias
        self.return_attn_coef = return_attn_coef

        self.dropout = tf.keras.layers.Dropout(dropout, name="dropout")
        # (sic) stores the dropout rate so get_config can serialize it
        self._droput_rate = dropout
        self.supports_masking = True # added

    def build(
        self,
        input_shape,
    ):
        """Create the projection weights.

        `input_shape` is a list of the query, key and optional value shapes; without
        a value shape, the key feature size is used for the value. Creates
        query/key/value kernels of shape [num_heads, in_features, head_size], a
        projection kernel of shape [num_heads, head_size, output_size] and, if
        enabled, a projection bias of shape [output_size].
        """
        num_query_features = input_shape[0][-1]
        num_key_features = input_shape[1][-1]
        num_value_features = input_shape[2][-1] if len(input_shape) > 2 else num_key_features
        output_size = self.output_size if self.output_size is not None else num_value_features
        self.query_kernel = self.add_weight(
            name="query_kernel",
            shape=[self.num_heads, num_query_features, self.head_size],
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        self.key_kernel = self.add_weight(
            name="key_kernel",
            shape=[self.num_heads, num_key_features, self.head_size],
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        self.value_kernel = self.add_weight(
            name="value_kernel",
            shape=[self.num_heads, num_value_features, self.head_size],
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        self.projection_kernel = self.add_weight(
            name="projection_kernel",
            shape=[self.num_heads, self.head_size, output_size],
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        if self.use_projection_bias:
            self.projection_bias = self.add_weight(
                name="projection_bias",
                shape=[output_size],
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
            )
        else:
            self.projection_bias = None

    def call_qkv(
        self,
        query,
        key,
        value,
        training=None,
    ):
        """Project query/key/value per head: (..., N, I) -> (..., N, num_heads, head_size).

        Raises:
            ValueError: if key and value differ in their second-to-last (length) dimension.

        `training` is accepted but not used here.
        """
        # verify shapes
        if key.shape[-2] != value.shape[-2]:
            raise ValueError(
                "the number of elements in 'key' must be equal to " "the same as the number of elements in 'value'"
            )
        # Linear transformations
        query = tf.einsum("...NI,HIO->...NHO", query, self.query_kernel)
        key = tf.einsum("...MI,HIO->...MHO", key, self.key_kernel)
        value = tf.einsum("...MI,HIO->...MHO", value, self.value_kernel)

        return query, key, value

    def call_attention(
        self,
        query,
        key,
        value,
        logits,
        training=None,
        mask=None,
        attention_mask=None,
    ):
        """Turn attention logits (..., H, N, M) into the projected output.

        Applies `attention_mask` (if given), softmax, dropout on the weights, the
        weighted sum over values, and the output projection (+ bias).
        `query`/`key` are only used for the mask shape checks; `mask` is accepted
        but not used.

        Returns:
            (output, attn_coef), where attn_coef is the softmax before dropout.
        """
        # mask = attention mask with shape [B, Tquery, Tkey] with 1 is for positions we want to attend, 0 for masked
        if attention_mask is not None:
            if len(attention_mask.shape) < 2: # this check originally read `mask`
                raise ValueError("'mask' must have at least 2 dimensions")
            if query.shape[-3] != attention_mask.shape[-2]:
                raise ValueError("mask's second to last dimension must be equal to " "the number of elements in 'query'")
            if key.shape[-3] != attention_mask.shape[-1]:
                raise ValueError("mask's last dimension must be equal to the number of elements in 'key'")
        # apply mask
        if attention_mask is not None:
            attention_mask = tf.cast(attention_mask, tf.float32)

            # possibly expand on the head dimension so broadcasting works
            if len(attention_mask.shape) != len(logits.shape):
                attention_mask = tf.expand_dims(attention_mask, -3)

            # masked positions (0) get a large negative logit (-10e9 == -1e10) before softmax
            logits += -10e9 * (1.0 - attention_mask)

        attn_coef = tf.nn.softmax(logits)

        # attention dropout
        attn_coef_dropout = self.dropout(attn_coef, training=training)

        # attention * value
        multihead_output = tf.einsum("...HNM,...MHI->...NHI", attn_coef_dropout, value)

        # Run the outputs through another linear projection layer. Recombining heads
        # is automatically done.
        output = tf.einsum("...NHI,HIO->...NO", multihead_output, self.projection_kernel)

        if self.projection_bias is not None:
            output += self.projection_bias

        return output, attn_coef

    def call(
        self,
        inputs,
        training=None,
        mask=None,
        attention_mask=None,
        **kwargs,
    ):
        """Scaled dot-product attention over `inputs = [query, key, value]`.

        `attention_mask` ([B, Tquery, Tkey], 1 = attend) is applied in
        `call_attention`; `mask` is forwarded but not used there.
        """
        query, key, value = inputs

        query, key, value = self.call_qkv(query, key, value, training=training)

        # Scale dot-product, doing the division to either query or key
        # instead of their product saves some computation
        depth = tf.constant(self.head_size, dtype=tf.float32)
        query /= tf.sqrt(depth)

        # Calculate dot product attention
        logits = tf.einsum("...NHO,...MHO->...HNM", query, key)

        output, attn_coef = self.call_attention(query, key, value, logits, training=training, mask=mask, attention_mask=attention_mask)

        if self.return_attn_coef:
            return output, attn_coef
        else:
            return output

    def compute_output_shape(
        self,
        input_shape,
    ):
        """Output shape (..., N_query, output_size), plus the attention-weight shape
        (..., num_heads, N_query, N_key) when `return_attn_coef` is True."""
        num_value_features = input_shape[2][-1] if len(input_shape) > 2 else input_shape[1][-1]
        output_size = self.output_size if self.output_size is not None else num_value_features

        output_shape = input_shape[0][:-1] + (output_size,)

        if self.return_attn_coef:
            num_query_elements = input_shape[0][-2]
            num_key_elements = input_shape[1][-2]
            attn_coef_shape = input_shape[0][:-2] + (
                self.num_heads,
                num_query_elements,
                num_key_elements,
            )

            return output_shape, attn_coef_shape
        else:
            return output_shape

    def get_config(self):
        """Serializable config: base Layer config plus every constructor argument."""
        config = super().get_config()

        config.update(
            head_size=self.head_size,
            num_heads=self.num_heads,
            output_size=self.output_size,
            dropout=self._droput_rate,
            use_projection_bias=self.use_projection_bias,
            return_attn_coef=self.return_attn_coef,
            kernel_initializer=tf.keras.initializers.serialize(self.kernel_initializer),
            kernel_regularizer=tf.keras.regularizers.serialize(self.kernel_regularizer),
            kernel_constraint=tf.keras.constraints.serialize(self.kernel_constraint),
            bias_initializer=tf.keras.initializers.serialize(self.bias_initializer),
            bias_regularizer=tf.keras.regularizers.serialize(self.bias_regularizer),
            bias_constraint=tf.keras.constraints.serialize(self.bias_constraint),
        )

        return config


class RelPositionMultiHeadAttention(MultiHeadAttention):
    """MultiHeadAttention with relative positional terms.

    `call` takes `inputs = [query, key, value, pos]`. The logits are the sum of a
    content term ((query + pos_bias_u) against key) and a position term
    ((query + pos_bias_v) against the projected `pos`, realigned by
    `relative_shift` and cut to the key length), divided by sqrt(head_size).
    """
    def build(
        self,
        input_shape,
    ):
        """Create the positional kernel [num_heads, pos_features, head_size] and the
        two per-head biases [num_heads, head_size], then build the parent weights
        from the [query, key, value] part of `input_shape`."""
        num_pos_features = input_shape[-1][-1]
        self.pos_kernel = self.add_weight(
            name="pos_kernel",
            shape=[self.num_heads, num_pos_features, self.head_size],
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        self.pos_bias_u = self.add_weight(
            name="pos_bias_u",
            shape=[self.num_heads, self.head_size],
            regularizer=self.kernel_regularizer,
            initializer=self.kernel_initializer,
            constraint=self.kernel_constraint,
        )
        self.pos_bias_v = self.add_weight(
            name="pos_bias_v",
            shape=[self.num_heads, self.head_size],
            regularizer=self.kernel_regularizer,
            initializer=self.kernel_initializer,
            constraint=self.kernel_constraint,
        )
        super(RelPositionMultiHeadAttention, self).build(input_shape[:-1])

    @staticmethod
    def relative_shift(x):
        """Realign a 4-D logits tensor of shape (B, H, N, P).

        Pads one zero column on the left of the last axis, reshapes to
        (B, H, P + 1, N), drops the first row, and reshapes back to x's shape.
        NOTE(review): this appears to be the Transformer-XL style relative shift;
        confirm it matches the position ordering produced for `pos`.
        """
        x_shape = tf.shape(x)
        x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [1, 0]])
        x = tf.reshape(x, [x_shape[0], x_shape[1], x_shape[3] + 1, x_shape[2]])
        x = tf.reshape(x[:, :, 1:, :], x_shape)
        return x

    def call(
        self,
        inputs,
        training=None,
        mask=None,
        attention_mask=None,
        **kwargs,
    ):
        """Relative-position attention over `inputs = [query, key, value, pos]`.

        `attention_mask` is applied in `call_attention`; `mask` is forwarded but not
        used there. Returns the output, plus the attention weights when
        `return_attn_coef` is True.
        """
        query, key, value, pos = inputs

        query, key, value = self.call_qkv(query, key, value, training=training)

        # project positions per head: (..., P, I) -> (..., P, num_heads, head_size)
        pos = tf.einsum("...MI,HIO->...MHO", pos, self.pos_kernel)

        query_with_u = query + self.pos_bias_u
        query_with_v = query + self.pos_bias_v

        logits_with_u = tf.einsum("...NHO,...MHO->...HNM", query_with_u, key)
        logits_with_v = tf.einsum("...NHO,...MHO->...HNM", query_with_v, pos)
        logits_with_v = self.relative_shift(logits_with_v)

        # keep only as many position columns as there are keys
        logits = logits_with_u + logits_with_v[:, :, :, : tf.shape(logits_with_u)[3]]

        depth = tf.constant(self.head_size, dtype=tf.float32)
        logits /= tf.sqrt(depth)

        output, attn_coef = self.call_attention(query, key, value, logits, training=training, mask=mask, attention_mask=attention_mask)

        if self.return_attn_coef:
            return output, attn_coef
        else:
            return output


class PositionalEncoding(tf.keras.layers.Layer):
    """Sinusoidal positional-encoding table sized from the shape of the input.

    For inputs of shape ``[B, T, D]`` (only the shape is used), ``call`` returns a
    ``[1, T * alpha + beta, D]`` tensor cast to the input dtype. Positions run backwards,
    from ``length - 1`` down to 0. Even feature columns hold ``sin`` and odd columns hold
    ``cos`` of ``pos / 10000**(2 * (i // 2) / D)``.

    Args:
        alpha: multiplier applied to the input length.
        beta: offset added to the scaled input length.
        name: layer name.
    """

    def __init__(
        self,
        alpha: int = 1,
        beta: int = 0,
        name="positional_encoding",
        **kwargs,
    ):
        # The layer is always non-trainable. Layer.get_config() emits a "trainable" key, so
        # drop any incoming value; otherwise rebuilding via from_config() would pass
        # ``trainable`` twice to the base constructor and raise a TypeError.
        kwargs.pop("trainable", None)
        super().__init__(trainable=False, name=name, **kwargs)
        self.alpha = alpha
        self.beta = beta
        self.supports_masking = True  # added: pass an incoming Keras mask through unchanged

    def build(
        self,
        input_shape,
    ):
        dmodel = input_shape[-1]
        # sin / cos values are interleaved pairwise, so the feature size must be even.
        assert dmodel % 2 == 0, f"Input last dim must be even: {dmodel}"

    @staticmethod
    def encode(
        max_len,
        dmodel,
    ):
        """Return the ``[1, max_len, dmodel]`` float32 table (positions max_len-1 .. 0)."""
        pos = tf.expand_dims(tf.range(max_len - 1, -1, -1.0, dtype=tf.float32), axis=1)
        index = tf.expand_dims(tf.range(0, dmodel, dtype=tf.float32), axis=0)

        # Columns 2i and 2i+1 share the frequency 1 / 10000^(2i / dmodel).
        pe = pos * (1 / tf.pow(10000.0, (2 * (index // 2)) / dmodel))

        # Sin / cos are [max_len, dmodel // 2]; zero padding + reshape interleaves them with
        # zeros so that sin lands on even columns and cos on odd columns.
        sin = tf.pad(tf.expand_dims(tf.sin(pe[:, 0::2]), -1), [[0, 0], [0, 0], [0, 1]], mode="CONSTANT", constant_values=0)
        sin = tf.reshape(sin, [max_len, dmodel])
        cos = tf.pad(tf.expand_dims(tf.cos(pe[:, 1::2]), -1), [[0, 0], [0, 0], [1, 0]], mode="CONSTANT", constant_values=0)
        cos = tf.reshape(cos, [max_len, dmodel])
        # Adding sin and cos gives the interleaved [time, size] table.
        pe = tf.add(sin, cos)
        return tf.expand_dims(pe, axis=0)  # [1, time, size]

    def call(
        self,
        inputs,
        **kwargs,
    ):
        # inputs shape [B, T, V]; only the shape is used.
        _, max_len, dmodel = shape_list(inputs)
        pe = self.encode(max_len * self.alpha + self.beta, dmodel)
        return tf.cast(pe, dtype=inputs.dtype)

    def get_config(self):
        conf = super().get_config()
        conf.update({"alpha": self.alpha, "beta": self.beta})
        return conf


class PositionalEncodingConcat(PositionalEncoding):
    """``PositionalEncoding`` variant whose table concatenates sin and cos halves.

    ``build``, ``call`` and ``get_config`` are inherited from ``PositionalEncoding``. The
    previous overrides of ``build`` and ``call`` were exact copies of the parent's, and
    ``self.encode`` still dispatches to this class's ``encode``. For an input of feature
    size D, the first D / 2 columns hold ``sin`` and the last D / 2 hold ``cos`` of
    ``pos / 10000**(2i / D)``, with positions running from ``length - 1`` down to 0.
    """

    @staticmethod
    def encode(
        max_len,
        dmodel,
    ):
        """Return the ``[1, max_len, dmodel]`` float32 table ``[sin | cos]``."""
        pos = tf.range(max_len - 1, -1, -1.0, dtype=tf.float32)

        index = tf.range(0, dmodel, 2.0, dtype=tf.float32)
        index = 1 / tf.pow(10000.0, (index / dmodel))

        # Outer product of positions and inverse frequencies -> [max_len, dmodel // 2].
        sinusoid = tf.einsum("i,j->ij", pos, index)
        pos = tf.concat([tf.sin(sinusoid), tf.cos(sinusoid)], axis=-1)

        return tf.expand_dims(pos, axis=0)



# Default L2 weight penalty (factor 1e-6) used by the Conformer modules below for their
# kernels, biases and LayerNorm parameters.
L2 = tf.keras.regularizers.L2(l2=1e-6)


class FFModule(tf.keras.layers.Layer):
    """Feed-forward module of a Conformer block with a scaled residual connection.

    Computes ``StochasticDepth([inputs, fc_factor * FFN(LayerNorm(inputs))])`` where FFN is
    Dense(4 * input_dim) -> swish -> Dropout -> Dense(input_dim) -> Dropout.

    Args:
        input_dim: size of the last axis of the inputs (and of the outputs).
        dropout: rate of both dropout layers.
        fc_factor: scale applied to the FFN branch before the residual add.
        kernel_regularizer: regularizer for the dense kernels and the LayerNorm gamma.
        bias_regularizer: regularizer for the dense biases and the LayerNorm beta.
        name: layer name, also used as the prefix of the sublayer names.
        survival_probability: ``survival_probability`` of the residual
            ``tfa.layers.StochasticDepth`` (previously hard-coded; 0.4 is still the default).
    """

    def __init__(
        self,
        input_dim,
        dropout=0.0,
        fc_factor=0.5,
        kernel_regularizer=L2,
        bias_regularizer=L2,
        name="ff_module",
        survival_probability=0.4,
        **kwargs,
    ):
        super(FFModule, self).__init__(name=name, **kwargs)
        # Accept serialized regularizers (dicts from get_config) as well as instances / None.
        kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
        bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
        self.input_dim = input_dim
        self.dropout = dropout
        self.fc_factor = fc_factor
        self.kernel_regularizer = kernel_regularizer
        self.bias_regularizer = bias_regularizer
        self.survival_probability = survival_probability
        self.ln = tf.keras.layers.LayerNormalization(
            name=f"{name}_ln",
            gamma_regularizer=kernel_regularizer,
            beta_regularizer=bias_regularizer,
        )
        self.ffn1 = tf.keras.layers.Dense(
            4 * input_dim,
            name=f"{name}_dense_1",
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
        )
        self.swish = tf.keras.layers.Activation(tf.nn.swish, name=f"{name}_swish_activation")
        self.do1 = tf.keras.layers.Dropout(dropout, name=f"{name}_dropout_1")
        self.ffn2 = tf.keras.layers.Dense(
            input_dim,
            name=f"{name}_dense_2",
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
        )
        self.do2 = tf.keras.layers.Dropout(dropout, name=f"{name}_dropout_2")
        self.res_add = tfa.layers.StochasticDepth(name=f"{name}_add", survival_probability=survival_probability)
        self.supports_masking = True  # added: pass an incoming Keras mask through

    def call(
        self,
        inputs,
        training=None,
        **kwargs,
    ):
        outputs = self.ln(inputs, training=training)
        outputs = self.ffn1(outputs, training=training)
        outputs = self.swish(outputs)
        outputs = self.do1(outputs, training=training)
        outputs = self.ffn2(outputs, training=training)
        outputs = self.do2(outputs, training=training)
        # Residual add of the scaled branch through stochastic depth.
        outputs = self.res_add([inputs, self.fc_factor * outputs])
        return outputs

    def get_config(self):
        """Return the constructor arguments so the layer can be rebuilt with ``from_config``.

        The previous version merged every sublayer's config into this dict. That overwrote
        keys such as ``name`` and added keys that ``__init__`` does not accept.
        """
        conf = super(FFModule, self).get_config()
        conf.update(
            {
                "input_dim": self.input_dim,
                "dropout": self.dropout,
                "fc_factor": self.fc_factor,
                "kernel_regularizer": tf.keras.regularizers.serialize(self.kernel_regularizer),
                "bias_regularizer": tf.keras.regularizers.serialize(self.bias_regularizer),
                "survival_probability": self.survival_probability,
            }
        )
        return conf


class MHSAModule(tf.keras.layers.Layer):
    """Multi-head self-attention module of a Conformer block with a residual connection.

    Computes ``StochasticDepth([inputs, Dropout(MHA(LayerNorm(inputs)))])``. ``call`` takes
    ``[inputs, pos]``:
    - With ``mha_type="relmha"``, the positional encoding ``pos`` is the fourth input of
      ``RelPositionMultiHeadAttention``.
    - With ``mha_type="mha"``, it is added to the normalised inputs before
      ``MultiHeadAttention``.

    Args:
        head_size: size of each attention head.
        num_heads: number of attention heads.
        dropout: dropout rate applied to the attention output.
        mha_type: ``"relmha"`` (relative positions) or ``"mha"``.
        kernel_regularizer: passed to the attention layer; also used for the LayerNorm gamma.
        bias_regularizer: passed to the attention layer; also used for the LayerNorm beta.
        name: layer name, also used as the prefix of the sublayer names.
        survival_probability: ``survival_probability`` of the residual
            ``tfa.layers.StochasticDepth`` (previously hard-coded; 0.5 is still the default).

    Raises:
        ValueError: if ``mha_type`` is neither ``"mha"`` nor ``"relmha"``.
    """

    def __init__(
        self,
        head_size,
        num_heads,
        dropout=0.0,
        mha_type="relmha",
        kernel_regularizer=L2,
        bias_regularizer=L2,
        name="mhsa_module",
        survival_probability=0.5,
        **kwargs,
    ):
        super(MHSAModule, self).__init__(name=name, **kwargs)
        # Accept serialized regularizers (dicts from get_config) as well as instances / None.
        kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
        bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
        self.head_size = head_size
        self.num_heads = num_heads
        self.dropout = dropout
        self.kernel_regularizer = kernel_regularizer
        self.bias_regularizer = bias_regularizer
        self.survival_probability = survival_probability
        self.ln = tf.keras.layers.LayerNormalization(
            name=f"{name}_ln",
            gamma_regularizer=kernel_regularizer,
            beta_regularizer=bias_regularizer,
        )
        if mha_type == "relmha":
            self.mha = RelPositionMultiHeadAttention(
                name=f"{name}_mhsa",
                head_size=head_size,
                num_heads=num_heads,
                kernel_regularizer=kernel_regularizer,
                bias_regularizer=bias_regularizer,
            )
        elif mha_type == "mha":
            self.mha = MultiHeadAttention(
                name=f"{name}_mhsa",
                head_size=head_size,
                num_heads=num_heads,
                kernel_regularizer=kernel_regularizer,
                bias_regularizer=bias_regularizer,
            )
        else:
            raise ValueError("mha_type must be either 'mha' or 'relmha'")
        self.do = tf.keras.layers.Dropout(dropout, name=f"{name}_dropout")
        self.res_add = tfa.layers.StochasticDepth(name=f"{name}_add", survival_probability=survival_probability)
        self.mha_type = mha_type
        self.supports_masking = True  # added: pass an incoming Keras mask through

    def call(
        self,
        inputs,
        training=None,
        mask=None,
        attention_mask=None,
        **kwargs,
    ):
        inputs, pos = inputs  # pos is the positional encoding
        outputs = self.ln(inputs, training=training)
        if self.mha_type == "relmha":
            outputs = self.mha([outputs, outputs, outputs, pos], training=training, mask=mask, attention_mask=attention_mask)
        else:
            # Absolute positions: add the encoding to the normalised inputs.
            outputs = outputs + pos
            outputs = self.mha([outputs, outputs, outputs], training=training, mask=mask, attention_mask=attention_mask)
        outputs = self.do(outputs, training=training)
        outputs = self.res_add([inputs, outputs])
        return outputs

    def get_config(self):
        """Return the constructor arguments so the layer can be rebuilt with ``from_config``.

        The previous version merged every sublayer's config into this dict. That overwrote
        keys such as ``name`` and added keys that ``__init__`` does not accept.
        """
        conf = super(MHSAModule, self).get_config()
        conf.update(
            {
                "head_size": self.head_size,
                "num_heads": self.num_heads,
                "dropout": self.dropout,
                "mha_type": self.mha_type,
                "kernel_regularizer": tf.keras.regularizers.serialize(self.kernel_regularizer),
                "bias_regularizer": tf.keras.regularizers.serialize(self.bias_regularizer),
                "survival_probability": self.survival_probability,
            }
        )
        return conf


class ConvModule(tf.keras.layers.Layer):
    """Convolution module of a Conformer block with a residual connection.

    Computes
    ``StochasticDepth([inputs, Dropout(PW2(swish(BN(DWConv(GLU(PW1(LayerNorm(inputs))))))))])``.
    PW1 and PW2 are kernel-size-1 ``Conv1D`` layers with ``2 * input_dim`` and ``input_dim``
    filters, and DWConv is a ``CausalDWConv1D``.

    Args:
        input_dim: size of the last axis of the inputs (and of the outputs).
        kernel_size: kernel size of the causal depthwise convolution.
        dropout: dropout rate applied before the residual add.
        depth_multiplier: depth multiplier of the depthwise convolution.
        kernel_regularizer: regularizer for the pointwise-conv kernels and the BatchNorm gamma.
        bias_regularizer: regularizer for the pointwise-conv biases and the BatchNorm beta.
        name: layer name, also used as the prefix of the sublayer names.
        survival_probability: ``survival_probability`` of the residual
            ``tfa.layers.StochasticDepth`` (previously hard-coded; 0.5 is still the default).
    """

    def __init__(
        self,
        input_dim,
        kernel_size=32,
        dropout=0.0,
        depth_multiplier=1,
        kernel_regularizer=L2,
        bias_regularizer=L2,
        name="conv_module",
        survival_probability=0.5,
        **kwargs,
    ):
        super(ConvModule, self).__init__(name=name, **kwargs)
        # Accept serialized regularizers (dicts from get_config) as well as instances / None.
        kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
        bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
        self.input_dim = input_dim
        self.kernel_size = kernel_size
        self.dropout = dropout
        self.depth_multiplier = depth_multiplier
        self.kernel_regularizer = kernel_regularizer
        self.bias_regularizer = bias_regularizer
        self.survival_probability = survival_probability
        # NOTE(review): unlike the other modules, this LayerNormalization has no name and no
        # regularizers. It is kept as is because adding regularizers would change the loss.
        self.ln = tf.keras.layers.LayerNormalization()
        self.pw_conv_1 = tf.keras.layers.Conv1D(
            filters=2 * input_dim,
            kernel_size=1,
            strides=1,
            padding="valid",
            name=f"{name}_pw_conv_1",
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
        )
        self.glu = GLU(name=f"{name}_glu")
        self.dw_conv = CausalDWConv1D(
            kernel_size=kernel_size,
            name=f"{name}_dw_conv",
            depth_multiplier = depth_multiplier,
        )
        self.bn = tf.keras.layers.BatchNormalization(
            name=f"{name}_bn",
            gamma_regularizer=kernel_regularizer,
            beta_regularizer=bias_regularizer,
        )
        self.swish = tf.keras.layers.Activation(
            tf.nn.swish,
            name=f"{name}_swish_activation",
        )
        self.pw_conv_2 = tf.keras.layers.Conv1D(
            filters=input_dim,
            kernel_size=1,
            strides=1,
            padding="valid",
            name=f"{name}_pw_conv_2",
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
        )
        self.do = tf.keras.layers.Dropout(dropout, name=f"{name}_dropout")
        self.res_add = tfa.layers.StochasticDepth(name=f"{name}_add", survival_probability=survival_probability)

    def call(
        self,
        inputs,
        training=None,
        **kwargs,
    ):
        # Every op below works directly on [B, T, C] tensors. The former
        # tf.reshape(outputs, [B, T, E]) calls reshaped tensors to their own shape and are removed.
        outputs = self.ln(inputs, training=training)
        outputs = self.pw_conv_1(outputs, training=training)
        outputs = self.glu(outputs)
        outputs = self.dw_conv(outputs, training=training)
        outputs = self.bn(outputs, training=training)
        outputs = self.swish(outputs)
        outputs = self.pw_conv_2(outputs, training=training)
        outputs = self.do(outputs, training=training)
        outputs = self.res_add([inputs, outputs])
        return outputs

    def get_config(self):
        """Return the constructor arguments so the layer can be rebuilt with ``from_config``.

        The previous version merged every sublayer's config into this dict. That overwrote
        keys such as ``name`` and added keys that ``__init__`` does not accept.
        """
        conf = super(ConvModule, self).get_config()
        conf.update(
            {
                "input_dim": self.input_dim,
                "kernel_size": self.kernel_size,
                "dropout": self.dropout,
                "depth_multiplier": self.depth_multiplier,
                "kernel_regularizer": tf.keras.regularizers.serialize(self.kernel_regularizer),
                "bias_regularizer": tf.keras.regularizers.serialize(self.bias_regularizer),
                "survival_probability": self.survival_probability,
            }
        )
        return conf



class CausalDWConv1D(tf.keras.layers.Layer):
    """Causal depthwise 1-D convolution.

    Left-pads the time axis with ``dilation_rate * (kernel_size - 1)`` zeros and then applies
    a stride-1 ``'valid'`` ``DepthwiseConv1D``. The output keeps the input length, and
    step t only sees input steps <= t.

    NOTE(review): the default ``depthwise_initializer`` is a single seeded instance created
    when the class is defined, and every layer that relies on the default shares it. Layers
    with identical kernel shapes may therefore start from identical weights; confirm this
    is intended.
    """

    def __init__(self,
        kernel_size=17,
        dilation_rate=1,
        use_bias=False,
        depthwise_initializer=tf.keras.initializers.glorot_uniform(seed=SEED),
        depth_multiplier = 1,
        name='', **kwargs):
        super().__init__(name=name,**kwargs)
        # Stored so get_config() can reproduce the layer.
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.use_bias = use_bias
        self.depth_multiplier = depth_multiplier
        self.causal_pad = tf.keras.layers.ZeroPadding1D((dilation_rate*(kernel_size-1),0),name=name + '_pad')
        self.dw_conv = tf.keras.layers.DepthwiseConv1D(
                            kernel_size,
                            strides=1,
                            dilation_rate=dilation_rate,
                            depth_multiplier=depth_multiplier,
                            padding='valid',
                            use_bias=use_bias,
                            depthwise_initializer=depthwise_initializer,
                            name=name + '_dwconv')
        self.supports_masking = True

    def call(self, inputs):
        x = self.causal_pad(inputs)
        x = self.dw_conv(x)
        return x

    def get_config(self):
        """Return the constructor arguments (the layer previously had no get_config)."""
        config = super().get_config()
        config.update(
            {
                "kernel_size": self.kernel_size,
                "dilation_rate": self.dilation_rate,
                "use_bias": self.use_bias,
                "depthwise_initializer": tf.keras.initializers.serialize(self.dw_conv.depthwise_initializer),
                "depth_multiplier": self.depth_multiplier,
            }
        )
        return config

class ConformerBlock(tf.keras.layers.Layer):
    """One Conformer block: FF module -> MHSA module -> Conv module -> FF module -> LayerNorm.

    ``call`` takes ``[inputs, pos]``, where ``pos`` is the positional encoding forwarded to
    the self-attention module. ``mask`` and ``attention_mask`` are forwarded to the
    self-attention module only.

    Args:
        input_dim: size of the last axis of the inputs (and of the outputs).
        dropout: dropout rate used inside every sub-module.
        fc_factor: residual scale of both feed-forward modules.
        head_size: size of each attention head.
        num_heads: number of attention heads.
        mha_type: ``"relmha"`` or ``"mha"`` (see ``MHSAModule``).
        kernel_size: kernel size of the causal depthwise convolution.
        depth_multiplier: depth multiplier of the depthwise convolution.
        kernel_regularizer: regularizer for kernels and LayerNorm gammas.
        bias_regularizer: regularizer for biases and LayerNorm betas.
        name: layer name, also used as the prefix of the sub-module names.
    """

    def __init__(
        self,
        input_dim,
        dropout=0.0,
        fc_factor=0.5,
        head_size=36,
        num_heads=4,
        mha_type="relmha",
        kernel_size=32,
        depth_multiplier=1,
        kernel_regularizer=L2,
        bias_regularizer=L2,
        name="conformer_block",
        **kwargs,
    ):
        super(ConformerBlock, self).__init__(name=name, **kwargs)
        # Accept serialized regularizers (dicts from get_config) as well as instances / None.
        kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
        bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
        self.input_dim = input_dim
        self.dropout = dropout
        self.fc_factor = fc_factor
        self.head_size = head_size
        self.num_heads = num_heads
        self.mha_type = mha_type
        self.kernel_size = kernel_size
        self.depth_multiplier = depth_multiplier
        self.kernel_regularizer = kernel_regularizer
        self.bias_regularizer = bias_regularizer
        self.ffm1 = FFModule(
            input_dim=input_dim,
            dropout=dropout,
            fc_factor=fc_factor,
            name=f"{name}_ff_module_1",
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
        )
        self.mhsam = MHSAModule(
            mha_type=mha_type,
            head_size=head_size,
            num_heads=num_heads,
            dropout=dropout,
            name=f"{name}_mhsa_module",
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
        )
        self.convm = ConvModule(
            input_dim=input_dim,
            kernel_size=kernel_size,
            dropout=dropout,
            name=f"{name}_conv_module",
            depth_multiplier=depth_multiplier,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
        )
        self.ffm2 = FFModule(
            input_dim=input_dim,
            dropout=dropout,
            fc_factor=fc_factor,
            name=f"{name}_ff_module_2",
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
        )
        self.ln = tf.keras.layers.LayerNormalization(
            name=f"{name}_ln",
            gamma_regularizer=kernel_regularizer,
            # beta is bias-like: regularize it with bias_regularizer, as the sub-modules do
            # (this previously used kernel_regularizer).
            beta_regularizer=bias_regularizer,
        )
        self.supports_masking = True  # added: pass an incoming Keras mask through

    def call(
        self,
        inputs,
        training=None,
        mask=None,
        attention_mask=None,
        **kwargs,
    ):
        inputs, pos = inputs  # pos is the positional encoding
        outputs = self.ffm1(inputs, training=training, **kwargs)
        outputs = self.mhsam([outputs, pos], training=training, mask=mask, attention_mask=attention_mask, **kwargs)
        outputs = self.convm(outputs, training=training, **kwargs)
        outputs = self.ffm2(outputs, training=training, **kwargs)
        outputs = self.ln(outputs, training=training)
        return outputs

    def get_config(self):
        """Return the constructor arguments so the layer can be rebuilt with ``from_config``.

        The previous version merged every sub-module's config into this dict. That overwrote
        keys such as ``name`` and added keys that ``__init__`` does not accept.
        """
        conf = super(ConformerBlock, self).get_config()
        conf.update(
            {
                "input_dim": self.input_dim,
                "dropout": self.dropout,
                "fc_factor": self.fc_factor,
                "head_size": self.head_size,
                "num_heads": self.num_heads,
                "mha_type": self.mha_type,
                "kernel_size": self.kernel_size,
                "depth_multiplier": self.depth_multiplier,
                "kernel_regularizer": tf.keras.regularizers.serialize(self.kernel_regularizer),
                "bias_regularizer": tf.keras.regularizers.serialize(self.bias_regularizer),
            }
        )
        return conf


def get_attention_mask(x_inp, mask_value):
    """Build an attention mask: 1.0 for real time steps, 0.0 for padded ones.

    A time step counts as padding when *every* feature equals ``mask_value``, the same rule
    ``tf.keras.layers.Masking`` uses. The previous version compared the *sum* of the features
    with ``mask_value``. That never matches when every feature of a padded step is filled
    with ``mask_value``, because the sum is then ``num_features * mask_value``.

    Args:
        x_inp: rank-3 input tensor (presumably [B, T, F]; the last axis is reduced).
        mask_value: value used to fill padded steps.

    Returns:
        float32 tensor of shape [B, 1, T].
    """
    is_padding = tf.reduce_all(tf.equal(x_inp, mask_value), axis=-1)
    padding_mask = 1 - tf.cast(is_padding, tf.float32)
    return padding_mask[:, tf.newaxis, :]

In [None]:
def CTCLoss(labels, logits):
    """Batch-mean CTC loss that uses the pad token as the CTC blank.

    Label lengths count the entries of ``labels`` that differ from ``pad_token_idx``; every
    logit sequence is treated as using the full time axis of ``logits``.
    """
    batch_size = tf.shape(logits)[0]
    time_steps = tf.shape(logits)[1]
    real_token_flags = tf.cast(tf.not_equal(labels, pad_token_idx), tf.int32)

    per_example_loss = classic_ctc_loss(
        labels=labels,
        logits=logits,
        label_length=tf.reduce_sum(real_token_flags, axis=-1),
        logit_length=tf.fill([batch_size], time_steps),
        blank_index=pad_token_idx,
    )
    return tf.reduce_mean(per_example_loss)

In [None]:
# Config param

# Experiment note: fairly close to 9.0 with dense dropout 75 (presumably 0.75), kernel size 3
# and stochastic-depth survival probabilities 0.7 / 0.65 / 0.55.

INPUT_SHAPE = [128, 276]  # per-example input shape: [time steps, features]
dim = 256  # Embedding dimension
num_blocs = 12  # number of stacked Conformer blocks
dropout_cformer = 0.0  # Dropout rate used inside every module of each Conformer block
num_heads = 8
head_size = dim // num_heads  # per-head size; head_size * num_heads must equal dim
depth_multiplier = 1  # depth multiplier of the depthwise Conv1D in the Conv module (not tuned)
kernel_size = 3  # kernel size of the depthwise Conv1D in the Conv module (the pointwise convs use 1)

# Fail fast instead of silently truncating head_size when dim is not divisible by num_heads.
assert head_size * num_heads == dim, f"dim ({dim}) must be divisible by num_heads ({num_heads})"

In [None]:
from tensorflow.keras.layers import SpatialDropout1D

def get_model(dim = dim, num_blocs = num_blocs):
    """Build and compile the Conformer + CTC model inside ``strategy.scope()``.

    The network is built from:
    - a bias-free Dense stem;
    - SpatialDropout1D(0.2);
    - ``num_blocs`` ConformerBlocks that share one positional encoding;
    - Dense(2 * dim, relu) and Dropout(0.7);
    - Dense(len(char_to_num)) logits, one per character including the pad / CTC-blank token.

    The model is compiled with ``CTCLoss`` and Lookahead(RectifiedAdam).

    Args:
        dim: embedding size of the stem and of every Conformer block.
        num_blocs: number of stacked Conformer blocks.

    Returns:
        The compiled ``tf.keras.Model``.
    """
    # The global ``head_size`` is derived from the global ``dim`` only. Recompute it from the
    # ``dim`` actually passed so that head_size * num_heads == dim also holds for a
    # non-default ``dim``. The value is unchanged for the default.
    block_head_size = dim // num_heads

    with strategy.scope():
        inp = tf.keras.Input(INPUT_SHAPE, name="input")

        # Input masking (Masking(mask_value=-1000.0) + get_attention_mask) is currently disabled.
        x = tf.keras.layers.Dense(dim, use_bias=False, name='stem_conv')(inp)
        # One positional encoding, computed from the stem output and shared by all blocks.
        pe = PositionalEncoding()
        pex = pe(x)

        x = SpatialDropout1D(0.2, name='spatial_dropout_pe')(x)  # spatial dropout

        conf_blocks = []
        for i in range(num_blocs):
            name = f'Conf_block_{i}'
            conf_block = ConformerBlock(input_dim=dim,
                                        head_size=block_head_size,
                                        dropout=dropout_cformer,
                                        num_heads=num_heads,
                                        depth_multiplier=depth_multiplier,
                                        kernel_size=kernel_size,
                                        name=name)
            conf_blocks.append(conf_block)

        for cblock in conf_blocks:
            x = cblock([x, pex])

        x = tf.keras.layers.Dense(dim*2, activation="relu", name='top_conv')(x)
        x = tf.keras.layers.Dropout(0.7)(x)
        x = tf.keras.layers.Dense(len(char_to_num))(x)

        model = tf.keras.Model(inp, x)
        # Adversarial training (nsl.keras.AdversarialRegularization with multiplier=0.4,
        # adv_step_size=0.05, 'infinity' grad norm) was tried here and is disabled.

        loss = CTCLoss

        # RectifiedAdam wrapped in Lookahead (slow weights synced every 5 steps).
        optimizer = tfa.optimizers.RectifiedAdam(sma_threshold=4)
        optimizer = tfa.optimizers.Lookahead(optimizer, sync_period=5)

        model.compile(loss=loss, optimizer=optimizer)

    return model

tf.keras.backend.clear_session()
# An adversarial-training variant (base_model, awp_model = get_model()) was previously run
# here and is disabled.
model = get_model()
first_inputs = batch[0]
model(first_inputs)  # one forward pass on a real batch before printing the summary
model.summary()

In [None]:
def num_to_char_fn(y):
    """Map token ids to characters via ``num_to_char``; unknown ids become ""."""
    chars = []
    for token_id in y:
        chars.append(num_to_char.get(token_id, ""))
    return chars

@tf.function()
def decode_phrase(pred):
    """Greedy CTC decoding of one sequence of per-frame logits.

    Takes the arg-max token of every frame, collapses runs of identical tokens to one
    token, then removes the blank token ``pad_token_idx``.

    Args:
        pred: logits of shape [time, num_classes].

    Returns:
        1-D int64 tensor with the decoded token ids.
    """
    x = tf.argmax(pred, axis=1)
    # Keep frame 0 and every frame whose token differs from the *previous* frame. Comparing
    # with the next frame instead always drops the final run: [a, a, b] decoded to [a] and
    # a single frame [a] decoded to nothing.
    is_run_start = tf.concat(
        [tf.ones_like(x[:1], dtype=tf.bool), tf.not_equal(x[1:], x[:-1])], axis=0
    )
    x = tf.boolean_mask(x, is_run_start, axis=0)
    x = tf.boolean_mask(x, tf.not_equal(x, pad_token_idx), axis=0)
    return x

# A utility function to decode the output of the network
def decode_batch_predictions(pred):
    """Greedy-decode every sequence of a batch of logits into a string."""
    return ["".join(num_to_char_fn(decode_phrase(sequence).numpy())) for sequence in pred]

In [None]:

# A callback class to output a few transcriptions during training
class CallbackEval(tf.keras.callbacks.Callback):
    """Displays a batch of outputs after every epoch.

    At the end of every epoch the callback:
    - saves the weights to ``model.h5``;
    - greedy-decodes each batch of ``dataset`` with the model being trained;
    - prints up to ``num_examples`` target / prediction pairs.

    Args:
        dataset: iterable of ``(inputs, labels)`` batches.
        num_examples: maximum number of pairs to print (default 32, as before).
    """

    def __init__(self, dataset, num_examples=32):
        super().__init__()
        self.dataset = dataset
        self.num_examples = num_examples

    def on_epoch_end(self, epoch: int, logs=None):
        # Use the model Keras attached to this callback rather than the global ``model``.
        self.model.save_weights("model.h5")
        predictions = []
        targets = []
        for inputs, labels in self.dataset:
            predictions.extend(decode_batch_predictions(self.model(inputs)))
            targets.extend("".join(num_to_char_fn(label.numpy())) for label in labels)
        print("-" * 100)
        # Bounded by the number of decoded examples: a short dataset previously raised IndexError.
        for i in range(min(self.num_examples, len(predictions))):
            print(f"Target    : {targets[i]}")
            print(f"Prediction: {predictions[i]}, len: {len(predictions[i])}")
            print("-" * 100)


# Callback function to check transcription on the val set.
validation_callback = CallbackEval(val_dataset.take(1))

In [None]:
N_EPOCHS = 300         # total number of training epochs
N_WARMUP_EPOCHS = 10   # warm-up epochs before the cosine decay
LR_MAX = 8e-4          # peak learning rate (8 * 1e-4)
WD_RATIO = 0.05        # weight decay is set to learning_rate * WD_RATIO
WARMUP_METHOD = "exp"  # "log": 10x per warm-up step; anything else: 2x per step

### Compute the normalized Levenshtein score on the validation set during training

In [None]:
# Materialise the validation dataset once so it can be re-scanned every epoch.
val_set = list(val_dataset)

with open("/kaggle/input/asl-fingerspelling/character_to_prediction_index.json", "r") as f:
    character_map = json.load(f)
# Reverse map: token id -> character (the JSON file does not contain the pad token).
rev_character_map = {index: char for char, index in character_map.items()}

class val_lev_callback(tf.keras.callbacks.Callback):
    """Keras callback that prints the validation Levenshtein score after every epoch."""

    def on_epoch_end(self, epoch: int, logs=None):
        # Delegates to the module-level helper, which evaluates on ``val_set``.
        calculate_val_lev()

def calculate_val_lev(net=None):
    """Print the normalized Levenshtein score of a model on ``val_set``.

    Each prediction is greedy-decoded with ``decode_phrase``. Prediction and label ids are
    mapped to characters with ``rev_character_map``; ids missing from it, such as the pad id,
    become "". Despite the printed "Lev distance" label, the value is the score
    ``(total_target_chars - total_edit_distance) / total_target_chars``, where higher is
    better. It is ``nan`` when the targets contain no characters.

    Args:
        net: model to evaluate; defaults to the global ``model`` (the previous behaviour).
    """
    if net is None:
        net = model
    preds = []
    targets = []
    for batch in val_set:
        inputs, labels = batch[0], batch[1]
        logits = net.predict(inputs, verbose=0)
        for seq_logits, label in zip(logits, labels):
            preds.append("".join(rev_character_map.get(s, "") for s in decode_phrase(seq_logits).numpy()))
            targets.append("".join(rev_character_map.get(s, "") for s in label.numpy()))

    total_chars = sum(len(phrase) for phrase in targets)
    total_dist = sum(lev.distance(pred, target) for pred, target in zip(preds, targets))
    # Explicit empty guard (np.sum-based division used to yield nan with a RuntimeWarning).
    score = (total_chars - total_dist) / total_chars if total_chars else float("nan")
    print('Lev distance: ' + str(score))

In [None]:
def lrfn(current_step, num_warmup_steps, lr_max, num_cycles=0.50, num_training_steps=N_EPOCHS):
    """Learning rate for ``current_step``: geometric warm-up followed by cosine decay.

    During warm-up, with ``k = num_warmup_steps - current_step``, the rate is
    ``lr_max * 0.1**k`` when ``WARMUP_METHOD == 'log'`` and ``lr_max * 2**-k`` otherwise.
    Afterwards it is ``max(0, 0.5 * (1 + cos(2 * pi * num_cycles * progress))) * lr_max``,
    where ``progress`` goes from 0 to 1 over the remaining steps.
    """
    if current_step >= num_warmup_steps:
        decay_span = max(1, num_training_steps - num_warmup_steps)
        progress = float(current_step - num_warmup_steps) / float(decay_span)
        cosine_factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
        return max(0.0, cosine_factor) * lr_max

    steps_left = num_warmup_steps - current_step
    if WARMUP_METHOD == 'log':
        return lr_max * 0.10 ** steps_left
    return lr_max * 2 ** -steps_left

def plot_lr_schedule(lr_schedule, epochs):
    """Plot a per-epoch learning-rate schedule and annotate selected points.

    Args:
        lr_schedule: list of learning rates; entry ``i`` is drawn at epoch ``i + 1``.
        epochs: number of epochs on the x axis. When ``epochs > 40`` only every fifth point
            and the final one are annotated.
    """
    plt.figure(figsize=(20, 10))
    # The leading None shifts the curve so entry i is drawn at x = i + 1 (1-based epochs).
    plt.plot([None] + lr_schedule + [None])

    # X labels: every epoch for short schedules, otherwise epoch 1 and every fifth epoch.
    tick_positions = np.arange(1, epochs + 1)
    tick_labels = [i if epochs <= 40 or i % 5 == 0 or i == 1 else None for i in range(1, epochs + 1)]
    plt.xlim([1, epochs])
    plt.xticks(tick_positions, tick_labels)

    # Leave 10% head-room above the peak for the annotations.
    plt.ylim([0, max(lr_schedule) * 1.1])

    schedule_info = f'start: {lr_schedule[0]:.1E}, max: {max(lr_schedule):.1E}, final: {lr_schedule[-1]:.1E}'
    plt.title(f'Step Learning Rate Schedule, {schedule_info}', size=18, pad=12)

    offset_y = (max(lr_schedule) - min(lr_schedule)) * 0.02
    for idx, val in enumerate(lr_schedule):
        # `==`, not `is`: int identity is an implementation detail (CPython caches only small
        # ints), so `idx is epochs - 1` was typically False for epochs > 257 and the final
        # point was never annotated.
        if epochs <= 40 or idx % 5 == 0 or idx == epochs - 1:
            if idx == 0:
                # Checked first: lr_schedule[idx - 1] would wrap around to the last entry.
                ha = 'right'
            elif idx < len(lr_schedule) - 1:
                ha = 'right' if lr_schedule[idx - 1] < val else 'left'
            else:
                ha = 'left'
            plt.plot(idx + 1, val, 'o', color='black')
            plt.annotate(f'{val:.1E}', xy=(idx + 1, val + offset_y), size=12, ha=ha)

    plt.xlabel('Epoch', size=16, labelpad=5)
    plt.ylabel('Learning Rate', size=16, labelpad=5)
    plt.grid()
    plt.show()

# Learning rate for encoder: one precomputed value per (0-based) epoch.
LR_SCHEDULE = [
    lrfn(epoch, num_warmup_steps=N_WARMUP_EPOCHS, lr_max=LR_MAX, num_cycles=0.50)
    for epoch in range(N_EPOCHS)
]
# Plot Learning Rate Schedule
plot_lr_schedule(LR_SCHEDULE, epochs=N_EPOCHS)
# Learning Rate Callback: look the epoch's rate up in the precomputed list.
lr_callback = tf.keras.callbacks.LearningRateScheduler(lambda epoch: LR_SCHEDULE[epoch], verbose=0)

# Custom callback to update weight decay with learning rate
class WeightDecayCallback(tf.keras.callbacks.Callback):
    """Set ``optimizer.weight_decay = learning_rate * wd_ratio`` at the start of each epoch.

    NOTE(review): the model is compiled with ``tfa.optimizers.Lookahead`` wrapping
    ``RectifiedAdam``. Assigning ``weight_decay`` on the wrapper may not reach the inner
    optimizer; confirm that the applied decay actually changes.

    Args:
        wd_ratio: ratio between weight decay and learning rate.
    """

    def __init__(self, wd_ratio=WD_RATIO):
        # Initialise the Keras Callback base (model / params attributes); this was skipped before.
        super().__init__()
        self.step_counter = 0
        self.wd_ratio = wd_ratio

    def on_epoch_begin(self, epoch, logs=None):
        optimizer = self.model.optimizer
        optimizer.weight_decay = optimizer.learning_rate * self.wd_ratio
        print(f'learning rate: {optimizer.learning_rate.numpy():.2e}, weight decay: {optimizer.weight_decay.numpy():.2e}')

class save_model_callback(tf.keras.callbacks.Callback):
    """Save the weights at the end of every epoch after ``min_epoch``.

    Files are named ``model_epoch_{epoch}_val_loss_{val_loss:.4f}.h5``.

    Args:
        min_epoch: weights are saved only for (0-based) epochs strictly greater than this
            value. It was previously hard-coded to 75, which is still the default.
    """

    def __init__(self, min_epoch=75):
        super().__init__()
        self.min_epoch = min_epoch

    def on_epoch_end(self, epoch: int, logs=None):
        if epoch > self.min_epoch:
            # Keras reports no 'val_loss' without validation data; use nan instead of a KeyError.
            val_loss = (logs or {}).get('val_loss', float('nan'))
            self.model.save_weights(f"model_epoch_{epoch}_val_loss_{val_loss:.4f}.h5")

In [None]:
# Early stopping on the validation loss with a patience of 10 epochs.
early_stop_callbacks = tf.keras.callbacks.EarlyStopping(patience=10, monitor='val_loss')

In [None]:

# Keras runs callbacks in list order, so lr_callback sets the epoch's learning rate
# before WeightDecayCallback reads it in on_epoch_begin.
fit_callbacks = [
    save_model_callback(),
    lr_callback,
    WeightDecayCallback(),
    val_lev_callback(),
    # early_stop_callbacks,
]
history = model.fit(
    train_dataset,
    epochs=N_EPOCHS,
    validation_data=val_dataset,
    callbacks=fit_callbacks,
    verbose=1,
)

In [None]:
# awp.base_model.save_weights(f"model_3.h5")

In [None]:
# DO-0-augV4-0.8: I did not get the best epoch (the run restarted), but it was at 10.8 instead of 11.08