In [None]:
import os
import re
import string
import random
from glob import glob
from argparse import Namespace

import tensorflow as tf
print(tf.__version__)
from tensorflow.keras import layers
from tensorflow.keras import models

# Detect a TPU and pick the matching distribution strategy.
try:
    # Constructing the resolver without arguments raises ValueError when no
    # TPU can be located; that case is handled by falling back below.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    # `tf.distribute.TPUStrategy` is the stable API; the
    # `tf.distribute.experimental.TPUStrategy` alias is deprecated.
    strategy = tf.distribute.TPUStrategy(tpu)
else:
    # No TPU: use whatever strategy is currently active (the default one
    # unless the caller has entered another strategy scope).
    strategy = tf.distribute.get_strategy()

print("REPLICAS: ", strategy.num_replicas_in_sync)

In [None]:
import wandb
print(wandb.__version__)

from wandb.keras import WandbMetricsLogger
from wandb.keras import WandbModelCheckpoint

In [None]:
def id_generator(size=6, chars=string.ascii_uppercase + string.digits):
    """Return a random string of `size` characters drawn from `chars`.

    Args:
        size: Number of characters in the result.
        chars: Sequence of candidate characters; defaults to uppercase ASCII
            letters followed by the digits 0-9.

    Returns:
        A `str` of length `size`; each character is picked independently
        with `random.choice`.
    """
    picked = []
    for _ in range(size):
        picked.append(random.choice(chars))
    return ''.join(picked)


# Identifier for this experiment, stored in `configs.experiment_id`.
random_id = id_generator(size=8)
print('Experiment Id: ', random_id)

In [None]:
# Hyperparameters for this run. The same object is passed to `wandb.init`
# as the run config and read by the preprocessing, model and training cells.
configs = Namespace(
    num_frames = 32,  # target frame count used when resizing each sequence
    batch_size = 128 if strategy.num_replicas_in_sync==1 else 16*strategy.num_replicas_in_sync,  # 128 on one replica, else 16 per replica
    experiment_id = random_id,  # random id generated earlier in the notebook
    epochs = 30,  # passed to `model.fit`
    resizing_interpolation = "nearest",  # `method` given to `tf.image.resize`
    learning_rate = 1e-4,  # Adam learning rate
    num_steps = 1.0,  # multiplier applied to `total_steps` to obtain `decay_steps`
    lips_patch_size = 8,  # patch size for the lips projection
    rh_patch_size = 7,  # patch size for the right-hand projection
    lh_patch_size = 7,  # patch size for the left-hand projection
    embed_dim = 128,  # token embedding width
    num_transformer_blocks=2,  # number of encoder layers
    num_heads = 4,  # attention heads per encoder layer
    layer_norm_eps = 1e-6  # epsilon of the final LayerNormalization
)


# Landmark indices selected along the landmark axis of the input
# (`tf.gather(..., indices=LIP, axis=2)` in `create_vivit_classifier`).
# Presumably MediaPipe face-mesh indices -- TODO confirm against the dataset.
LIP = [
    61, 185, 40, 39, 37, 0, 267, 269, 270, 409,
    291, 146, 91, 181, 84, 17, 314, 405, 321, 375,
    78, 191, 80, 81, 82, 13, 312, 311, 310, 415,
    95, 88, 178, 87, 14, 317, 402, 318, 324, 308,
]

# Landmark indices for the right eye region. Not referenced by the code
# visible in this notebook.
RIGHT_EYE = [
    246, 161, 160, 159, 158, 157, 173,
    33, 7, 163, 144, 145, 153, 154, 155, 133,
    247, 30, 29, 27, 28, 56, 190,
    130, 25, 110, 24, 23, 22, 26, 112, 243,
    113, 225, 224, 223, 222, 221, 189,
    226, 31, 228, 229, 230, 231, 232, 233, 244,
    143, 111, 117, 118, 119, 120, 121, 128, 245,
]

# Landmark indices for the left eye region. Not referenced by the code
# visible in this notebook.
# NOTE(review): the first row has 6 entries while the matching RIGHT_EYE row
# has 7 -- an index may be missing here; verify before using this list.
LEFT_EYE = [
    466, 387, 386, 385, 384, 398,
    263, 249, 390, 373, 374, 380, 381, 382, 362,
    467, 260, 259, 257, 258, 286, 414,
    359, 255, 339, 254, 253, 252, 256, 341, 463,
    342, 445, 444, 443, 442, 441, 413,
    446, 261, 448, 449, 450, 451, 452, 453, 464,
    372, 340, 346, 347, 348, 349, 350, 357, 465,
]

In [None]:
# Start a Weights & Biases run for this training job; `configs` is recorded
# as the run's configuration.
run = wandb.init(
    project="kaggle-asl-tubelet",
    config=configs,
    job_type="train",
)

In [None]:
def natural_keys(text):
    """Build a sort key that orders embedded numbers numerically.

    `text` is split on runs of decimal digits; each run becomes an `int`
    and the text in between is kept as `str`, so that e.g. "file2" sorts
    before "file10".

    Args:
        text: The string to build a key for.

    Returns:
        A list alternating `str` and `int` pieces, always starting and
        ending with a `str` piece (possibly empty).
    """
    # `re.split` with a capturing group alternates: even positions hold the
    # text between digit runs, odd positions hold the captured `\d+` runs.
    # Converting by position (instead of testing `str.isdigit()`) avoids
    # calling `int()` on characters such as "²" that `isdigit()` accepts
    # but `int()` rejects.
    pieces = re.split(r'(\d+)', text)
    return [int(piece) if index % 2 else piece for index, piece in enumerate(pieces)]


# Collect the TFRecord shards and order them by the numbers in their names.
tfrecords = glob("../data/tfrecords-participants/*.tfrec")
tfrecords = sorted(tfrecords, key=natural_keys)

In [None]:
# The first 18 shards (in natural order) are used for training and the
# remaining shards for validation.
train_tfrecords = tfrecords[:18]
valid_tfrecords = tfrecords[18:]
print(len(train_tfrecords), len(valid_tfrecords))

def parse_sequence(serialized_sequence):
    """Deserialize a serialized tensor into a `tf.float32` tensor.

    Args:
        serialized_sequence: Scalar string tensor accepted by
            `tf.io.parse_tensor`.

    Returns:
        The decoded float32 tensor.
    """
    decoded = tf.io.parse_tensor(serialized_sequence, out_type=tf.float32)
    return decoded


def parse_tfrecord_fn(example):
    """Parse one serialized example into a dict of scalar features.

    Args:
        example: A serialized example read from the TFRecord dataset.

    Returns:
        A dict with keys "n_frames" (float32), "frames" (string) and
        "label" (int64), each a scalar tensor.
    """
    schema = dict(
        n_frames=tf.io.FixedLenFeature([], tf.float32),
        frames=tf.io.FixedLenFeature([], tf.string),
        label=tf.io.FixedLenFeature([], tf.int64),
    )
    return tf.io.parse_single_example(example, schema)


@tf.function
def preprocess_frames(frames):
    """
    Resize a sequence to a fixed number of frames.

    `tf.image.resize` is called with target size `(configs.num_frames, 543)`
    and interpolation `configs.resizing_interpolation`. No NaN handling is
    done here; that happens in `normalize_frames`.
    """
    target_size = (configs.num_frames, 543)
    return tf.image.resize(
        frames, target_size, method=configs.resizing_interpolation
    )


def normalize_frames(frames):
    """
    Standardize one sequence and zero out non-finite values.

    The mean and standard deviation are taken over every non-NaN value of
    `frames` (the boolean selection flattens them into one vector). After
    standardizing, every entry that is not finite (NaN or +/-inf) is
    replaced by 0.
    """
    nan_mask = tf.math.is_nan(frames)
    valid_values = frames[~nan_mask]

    mean = tf.math.reduce_mean(valid_values, axis=0, keepdims=True)
    std = tf.math.reduce_std(valid_values, axis=0, keepdims=True)
    standardized = (frames - mean) / std

    return tf.where(
        tf.math.is_finite(standardized), standardized, tf.zeros_like(standardized)
    )


def parse_data(example):
    """Turn a parsed example into a `(frames, label)` training pair.

    Args:
        example: Dict produced by `parse_tfrecord_fn` with the keys
            "n_frames", "frames" and "label".

    Returns:
        A tuple `(frames, label)`: `frames` is the sequence after
        `preprocess_frames` and `normalize_frames`; `label` is a one-hot
        vector of depth 250.
    """
    # "n_frames" is parsed as float32, but a reshape dimension must be an
    # integer, so cast explicitly instead of relying on implicit conversion.
    n_frames = tf.cast(example["n_frames"], tf.int32)

    # Decode the serialized frames and restore the (n_frames, 543, 3) layout.
    decoded = parse_sequence(example["frames"])
    frames = tf.reshape(decoded, shape=(n_frames, 543, 3))
    frames = preprocess_frames(frames)
    frames = normalize_frames(frames)

    # One-hot encode the integer class id.
    label = tf.one_hot(example["label"], depth=250)

    return frames, label

AUTOTUNE = tf.data.AUTOTUNE

train_ds = tf.data.TFRecordDataset(train_tfrecords)
valid_ds = tf.data.TFRecordDataset(valid_tfrecords)

# Training pipeline: parse -> shuffle (buffer of 1024) -> decode/normalize
# -> batch -> prefetch.
trainloader = train_ds.map(parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
trainloader = trainloader.shuffle(1024)
trainloader = trainloader.map(parse_data, num_parallel_calls=AUTOTUNE)
trainloader = trainloader.batch(configs.batch_size)
trainloader = trainloader.prefetch(AUTOTUNE)

# Validation pipeline: same steps without shuffling.
validloader = valid_ds.map(parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
validloader = validloader.map(parse_data, num_parallel_calls=AUTOTUNE)
validloader = validloader.batch(configs.batch_size)
validloader = validloader.prefetch(AUTOTUNE)

In [None]:
class TubeletEmbedding1(layers.Layer):
    """Project a 5-D input into a sequence of `embed_dim`-wide tokens.

    A single `Conv3D` with kernel `(patch_size, patch_size, 1)` and stride
    `patch_size` on all three axes produces the patches, which are then
    reshaped to `(-1, embed_dim)` per sample.
    """

    def __init__(self, embed_dim, patch_size, **kwargs):
        super().__init__(**kwargs)
        kernel = (patch_size, patch_size, 1)
        stride = (patch_size,) * 3
        self.projection = layers.Conv3D(
            filters=embed_dim, kernel_size=kernel, strides=stride, padding="VALID"
        )
        self.flatten = layers.Reshape(target_shape=(-1, embed_dim))

    def call(self, videos):
        """Return the flattened patch embeddings for `videos`."""
        return self.flatten(self.projection(videos))


class TubeletEmbedding2(layers.Layer):
    """Embed lips, right-hand and left-hand landmarks into one token sequence.

    `call` expects an input whose axis 2 holds 40 lip landmarks followed by
    21 right-hand and 21 left-hand landmarks (the slices 0:40, 40:61 and
    61:82). Each group gets its own `Conv3D` projection; the projections are
    concatenated along axis 2, reshaped to `(-1, embed_dim)` per sample and
    layer-normalized.

    Args:
        embed_dim: Number of filters of every projection, i.e. token width.
        patch_sizes: Three patch sizes in the order (lips, right hand,
            left hand).
    """

    def __init__(self, embed_dim, patch_sizes, **kwargs):
        super().__init__(**kwargs)
        assert len(patch_sizes) == 3 # lips, right_hand, left_hand
        lips_patch_size, rh_patch_size, lh_patch_size = patch_sizes

        # NOTE(review): an int `strides` applies the patch size to all three
        # convolution axes, including the last one where the kernel extent
        # is 1 -- confirm that striding over that axis is intended.
        self.lips_projection = layers.Conv3D(
            filters=embed_dim,
            kernel_size=(lips_patch_size, lips_patch_size, 1),
            strides=lips_patch_size,
            padding="VALID",
        )
        self.rh_projection = layers.Conv3D(
            filters=embed_dim,
            kernel_size=(rh_patch_size, rh_patch_size, 1),
            strides=rh_patch_size,
            padding="VALID",
        )
        self.lh_projection = layers.Conv3D(
            filters=embed_dim,
            kernel_size=(lh_patch_size, lh_patch_size, 1),
            strides=lh_patch_size,
            padding="VALID",
        )

        self.concat = layers.Concatenate(axis=2)
        self.flatten = layers.Reshape(target_shape=(-1, embed_dim))
        self.layer_norm = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        """Return layer-normalized patch tokens for the three landmark groups."""
        # Split the landmark axis per group and add a trailing channel axis
        # so each slice is a valid 5-D Conv3D input.
        lips = inputs[:, :, 0:40, :, tf.newaxis]
        rh = inputs[:, :, 40:61, :, tf.newaxis]
        lh = inputs[:, :, 61:82, :, tf.newaxis]

        lips_projected = self.lips_projection(lips)
        rh_projected = self.rh_projection(rh)
        lh_projected = self.lh_projection(lh)

        # Use each group's projection exactly once (the left hand must not
        # be replaced by a second copy of the right hand).
        projected_patches = self.concat([lips_projected, rh_projected, lh_projected])
        flattened_patches = self.flatten(projected_patches)

        # Return the normalized tokens rather than the pre-norm ones.
        return self.layer_norm(flattened_patches)
    
    
class PositionalEncoder(layers.Layer):
    """Add a learned position embedding to a sequence of tokens.

    The embedding table is created in `build` with one row per token of the
    (rank-3) input shape.
    """
    # TODO: consider sine/cosine or rotary positional encodings instead.

    def __init__(self, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim

    def build(self, input_shape):
        """Create the embedding table and the fixed vector of position ids."""
        _batch, num_tokens, _width = input_shape
        self.position_embedding = layers.Embedding(
            input_dim=num_tokens, output_dim=self.embed_dim
        )
        self.positions = tf.range(num_tokens)

    def call(self, encoded_tokens):
        """Return `encoded_tokens` plus the embedding of each position."""
        return encoded_tokens + self.position_embedding(self.positions)

In [None]:
class BaseAttention(tf.keras.layers.Layer):
  """Holds the sub-layers shared by attention blocks.

  All keyword arguments are forwarded to `MultiHeadAttention`; subclasses
  define `call`.
  """

  def __init__(self, **kwargs):
    super().__init__()
    keras_layers = tf.keras.layers
    self.mha = keras_layers.MultiHeadAttention(**kwargs)
    self.layernorm = keras_layers.LayerNormalization()
    self.add = keras_layers.Add()
    
    
class GlobalSelfAttention(BaseAttention):
  """Self-attention over the whole sequence with residual add and layer norm."""

  def call(self, x):
    # Query, key and value are all the same sequence.
    attended = self.mha(query=x, value=x, key=x)
    residual = self.add([x, attended])
    return self.layernorm(residual)


class FeedForward(tf.keras.layers.Layer):
  """Two-layer MLP with dropout, a residual connection and layer norm.

  Args:
    d_model: Output width of the second dense layer.
    dff: Width of the hidden (ReLU) dense layer.
    dropout_rate: Rate of the dropout applied after the second dense layer.
  """

  def __init__(self, d_model, dff, dropout_rate=0.1):
    super().__init__()
    hidden = tf.keras.layers.Dense(dff, activation='relu')
    projection = tf.keras.layers.Dense(d_model)
    dropout = tf.keras.layers.Dropout(dropout_rate)
    self.seq = tf.keras.Sequential([hidden, projection, dropout])
    self.add = tf.keras.layers.Add()
    self.layer_norm = tf.keras.layers.LayerNormalization()

  def call(self, x):
    transformed = self.seq(x)
    return self.layer_norm(self.add([x, transformed]))

In [None]:
class EncoderLayer(tf.keras.layers.Layer):
  """One encoder block: global self-attention followed by a feed-forward block.

  Args:
    d_model: Token width; also used as `key_dim` of the attention layer.
    num_heads: Number of attention heads.
    dff: Hidden width of the feed-forward block.
    dropout_rate: Dropout rate used by both the attention layer and the
      feed-forward block.
  """

  def __init__(self,*, d_model, num_heads, dff, dropout_rate=0.1):
    super().__init__()

    self.self_attention = GlobalSelfAttention(
        num_heads=num_heads,
        key_dim=d_model,
        dropout=dropout_rate)

    # Forward `dropout_rate` so a non-default value also reaches the
    # feed-forward block instead of silently falling back to its default.
    self.ffn = FeedForward(d_model, dff, dropout_rate=dropout_rate)

  def call(self, x):
    x = self.self_attention(x)
    x = self.ffn(x)
    return x


class ViViTEncoder(tf.keras.layers.Layer):
  """Positional encoding, dropout and a stack of `EncoderLayer` blocks.

  Args:
    num_layers: Number of encoder layers in the stack.
    d_model: Token width; the feed-forward hidden width is `4 * d_model`.
    num_heads: Number of attention heads per layer.
    dropout_rate: Dropout rate for the input dropout and the layers.
  """

  def __init__(self, num_layers, d_model, num_heads, dropout_rate=0.1):
    super().__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.pos_embedding = PositionalEncoder(embed_dim=d_model)

    layer_kwargs = dict(
        d_model=d_model,
        num_heads=num_heads,
        dff=4*d_model,
        dropout_rate=dropout_rate)
    self.enc_layers = [EncoderLayer(**layer_kwargs) for _ in range(num_layers)]
    self.dropout = tf.keras.layers.Dropout(dropout_rate)

  def call(self, x):
    # `x` holds already-embedded tokens; add the learned positions to them.
    x = self.pos_embedding(x)

    # Dropout on the position-encoded tokens.
    x = self.dropout(x)

    # Run the tokens through every encoder layer in order.
    for encoder_layer in self.enc_layers:
      x = encoder_layer(x)

    return x

In [None]:
def create_vivit_classifier(
    tubelet_embedder,
    positional_encoder,
    num_layers=configs.num_transformer_blocks,
    num_heads=configs.num_heads,
    embed_dim=configs.embed_dim,
    layer_norm_eps=configs.layer_norm_eps,
    num_frames=configs.num_frames,
):
    """Build the landmark classifier as a Keras functional model.

    Args:
        tubelet_embedder: Layer that turns the selected landmarks into tokens.
        positional_encoder: Accepted for interface compatibility but not
            used; `ViViTEncoder` creates its own positional encoder.
        num_layers: Number of encoder layers.
        num_heads: Number of attention heads per encoder layer.
        embed_dim: Token width passed to the encoder.
        layer_norm_eps: Epsilon of the final layer normalization.
        num_frames: Number of frames per input sequence. Defaults to
            `configs.num_frames`, the length `preprocess_frames` resizes to,
            so the input shape follows the config instead of a literal.

    Returns:
        A `tf.keras.Model` mapping `(num_frames, 543, 3)` inputs to a
        250-way softmax.
    """
    # Get the input layer
    inputs = layers.Input(shape=(num_frames, 543, 3))

    # Get lips, right hand and left hand
    lip_inputs = tf.gather(inputs, indices=LIP, axis=2)
    left_hand_inputs = inputs[:, :, 468:489, :]
    right_hand_inputs = inputs[:, :, 522:, :]

    # Order must match the slices the embedder takes: lips, right, left.
    landmarks = tf.keras.layers.Concatenate(axis=2)(
        [lip_inputs, right_hand_inputs, left_hand_inputs]
    )

    # Create patches.
    patches = tubelet_embedder(landmarks)

    # Apply Encoder
    encoder_output = ViViTEncoder(num_layers, embed_dim, num_heads)(patches)

    # Layer normalization, then keep only the first token as the
    # sequence representation (this is not an average over tokens).
    # NOTE(review): no dedicated class token is prepended, so token 0 is an
    # ordinary patch token -- confirm this is preferred over pooling.
    representation = layers.LayerNormalization(epsilon=layer_norm_eps)(encoder_output)
    representation = representation[:, 0]

    # Classify outputs.
    outputs = layers.Dense(units=250, activation="softmax")(representation)

    # Create the Keras model.
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model

In [None]:
# Reset Keras global state before building a fresh model.
tf.keras.backend.clear_session()

# Build the layers and the model under the distribution strategy's scope.
with strategy.scope():
    tubelet_embedder = TubeletEmbedding2(
        configs.embed_dim,
        patch_sizes=(configs.lips_patch_size, configs.rh_patch_size, configs.lh_patch_size)
    )
    # NOTE(review): `create_vivit_classifier` accepts this encoder but does
    # not use it; `ViViTEncoder` builds its own PositionalEncoder.
    positional_encoder = PositionalEncoder(configs.embed_dim)
    model = create_vivit_classifier(tubelet_embedder, positional_encoder)

# 585 is presumably the number of training steps per epoch -- TODO confirm.
total_steps = 585*configs.epochs
# Only consumed by the commented-out cosine decay schedule below.
decay_steps = total_steps*configs.num_steps

# cosine_decay_scheduler = tf.keras.optimizers.schedules.CosineDecay(
#     initial_learning_rate = configs.learning_rate,
#     decay_steps = decay_steps,
#     alpha=0.1
# )

# The labels are one-hot vectors over 250 classes and the model ends in a
# 250-way softmax, i.e. single-label multi-class classification. Categorical
# cross-entropy is the matching loss; binary cross-entropy would score every
# class as an independent yes/no problem.
model.compile(
    tf.keras.optimizers.Adam(configs.learning_rate),
    "categorical_crossentropy",
    metrics=["acc"]
)

model.summary(expand_nested=True)

In [None]:
# Draw the model graph with tensor shapes and layer names, expanding nested models.
tf.keras.utils.plot_model(model, show_shapes=True, show_layer_names=True, expand_nested=True)

In [None]:
# Early stopping on validation loss with a patience of 7 epochs, restoring
# the best weights. NOTE: it is defined here but currently commented out of
# `callbacks`, so it has no effect on training.
earlystopper = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=7,
    verbose=0,
    mode="auto",
    restore_best_weights=True,
)

# checkpoint_options = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")

# All callbacks are disabled, so `model.fit` runs for the full number of
# epochs and nothing is logged to W&B or checkpointed from here.
callbacks = [
#     earlystopper,
#     WandbMetricsLogger(log_freq=2),
#     WandbModelCheckpoint(
#         filepath=f"model-{configs.experiment_id}",
#         save_best_only=True,
#         options=checkpoint_options,
#     ),
]

# Train on `trainloader` and validate on `validloader` after each epoch.
model.fit(
    trainloader,
    epochs=configs.epochs,
    validation_data=validloader,
    callbacks=callbacks,
)

In [None]:
# Evaluate on the validation loader; the two values are the compiled loss
# and the single compiled metric ("acc").
eval_loss, eval_acc = model.evaluate(validloader)
# wandb.log({"eval_loss": eval_loss, "eval_acc": eval_acc})

In [None]:
# Close the Weights & Biases run started by `wandb.init`.
wandb.finish()