In [None]:
# imports for model

import keras_hub
import pathlib
import random

import keras
from keras import ops

import tensorflow.data as tf_data
import tensorflow.strings as tf_string

In [None]:
# imports for pre-parsing

from pyparsing import Word, alphas, nums
import pyparsing as pp
import re

In [None]:
# handshapes
#
# Handshape labels recognised by the gloss grammar. The list is handed to
# `pp.one_of` when the `handshape` parser element is built, so every entry is
# matched as a literal string.

handshapelist = [
  '1',
  '3',
  '4',
  '5',
  '6',
  '7',
  '8',
  '9',
  '10',
  '25',
  'A',
  'B',
  'C',
  'D',
  'E',
  'F',
  'G',
  'H-U',
  'I',
  'K',
  'L',
  'M',
  'N',
  'O',
  'R',
  'S',
  'T',
  'U',
  'V',
  'W',
  'X',
  'Y',
  'C-L',
  'U-L',
  'B-L',
  'P-K',
  'Q-G',
  'L-X',
  'I-L-Y',
  '5-C',
  '5-C-L',
  '5-C-tt',
  'alt-M',
  'alt-N',
  'alt-P',
  'B-xd',
  'baby-O',
  'bent-1',
  'bent-B',
  'bent-B-L',
  'bent-horns',
  'bent-L',
  'bent-M',
  'bent-N',
  'bent-U',
  'cocked-8',
  'cocked-F',
  'cocked-S',
  'crvd-5',
  'crvd-B',
  'crvd-flat-B',
  'crvd-L',
  'crvd-sprd-B',
  'crvd-U',
  'crvd-V',
  'fanned-flat-O',
  'flat-B',
  'flat-C',
  'flat-F',
  'flat-G',
  'flat-O-2',
  'flat-O',
  'full-M',
  'horns',
  'loose-E',
  'O-2-horns',
  'open-8',
  'open-F',
  'sml-C-3',
  'tight-C-2',
  'tight-C',
  'X-over-thumb'
]

In [None]:
# regex rules
#
# Every fragment below is written for verbose mode (`re.X`): whitespace and
# `#` comments inside the patterns are ignored by the regex engine.

number_regexp = r"""
    [0-9]+                      # one or more digits
    (?:\.[0-9]+)?               # optional decimal
    s?                          # optional trailing 's'
"""

alpha_regexp = r"""
    (?!                         # negative lookahead for:
        (?:THUMB-)?             # optional THUMB-
        (?:IX-|POSS-|SELF-)     # followed by IX-, POSS-, SELF-
    )
    [A-Z0-9]                    # starts with uppercase or digit
    (?:                         # optionally followed by:
        [A-Z0-9'-]*             #   more letters, digits, hyphen or apostrophe
        [A-Z0-9]                #   ends on a letter or digit
    )?
    (?:                         # optionally followed by:
        \.                      #   a period
      |                         # or
        :[0-9]                  #   colon-digit (e.g., :2)
    )?
"""

lookahead_regexp = r"""
    (?:
        (?![a-z])               # not followed by lowercase letter
      | (?=wg)                  # unless it's specifically "wg"
    )
"""

# Full word pattern: a number or an uppercase gloss, followed by the lookahead.
# This must be an f-string (`rf`) so the fragments above are actually
# interpolated; as a plain string the `{...}` placeholders would be matched
# literally and the pattern could never match real gloss text. None of the
# fragments contain braces, so no escaping is needed.
word_all_regexp = rf"""
    (?: {number_regexp} | {alpha_regexp} )
    {lookahead_regexp}
"""

In [None]:
# conventions kept for parsing
#
# One pyparsing element per annotation convention. They are combined, in
# priority order, into `full_grammar`.

cl_prefix = pp.one_of(["CL", "DCL", "LCL", "SCL", "BCL", "BPCL", "PCL", "ICL"])
ns_prefix = pp.Literal("ns")
fs_prefix = pp.Literal("fs")
lex_exceptions = pp.one_of(["part", "WHAT"])
aspect_text = pp.Literal("aspect")
index_core_ix = pp.Literal("IX")
other_index_core = pp.one_of(["POSS", "SELF"])
handshape = pp.one_of(handshapelist)
person = pp.one_of(["1p", "2p", "3p"])
dash = pp.Literal("-")
arc = pp.Literal("arc")
loc = pp.Literal("loc")
pl = pp.Literal("pl")
compound = pp.Literal("+")
hashtag = pp.Literal("#")
choice = pp.Literal("/")
sym = pp.Literal(">")
par1 = pp.Literal("(")
par2 = pp.Literal(")")
contraction = pp.Literal("^")
colon = pp.Literal(":")
omit_quote = pp.Literal("xx")
period = pp.Literal(".")
# Single-character fallbacks. They are built from `pp.alphas` / `pp.nums`
# because the names `alphas` / `nums` are rebound right here: reading the bare
# names on the right-hand side would hand a `Word` object to `pp.Word` the
# second time this cell is executed.
alphas = pp.Word(pp.alphas, max=1)
nums = pp.Word(pp.nums, max=1)
# `word_all_regexp` is written in verbose style, hence the re.X flag.
word = pp.Regex(word_all_regexp, flags=re.X)

In [None]:
# grammar rules
#
# A gloss string is one or more tokens. The alternatives are tried in the
# order listed and the first one that matches wins, so the order below is
# significant.

full_grammar = pp.OneOrMore(
    pp.MatchFirst(
        [
            cl_prefix,         # classifiers like CL, DCL, etc.
            ns_prefix,         # non-specific ns
            fs_prefix,         # fingerspelling fs
            index_core_ix,     # IX
            other_index_core,  # POSS, SELF
            person,            # 1p, 2p, 3p
            lex_exceptions,    # part, WHAT
            aspect_text,       # aspect
            arc,               # arc
            loc,               # loc
            pl,                # pl
            handshape,         # handshapes like B, 1, 5, etc.
            compound,          # +
            hashtag,           # #
            choice,            # /
            sym,               # >
            contraction,       # ^
            colon,             # :
            dash,              # -
            par1,              # (
            par2,              # )
            omit_quote,        # xx
            period,            # .
            word,              # regex-based word / number
            nums,              # numbers last resort
            alphas,            # fallback LAST
        ]
    )
)

In [None]:
# testing grammar parsing
#
# `parse_all=True` makes pyparsing raise if any part of the string is left
# unparsed. `as_list()` is used on both lines; `asList()` is the legacy
# camelCase alias of the same method.

trial = full_grammar.parse_string("SCL:1xx", parse_all=True).as_list()
trial2 = full_grammar.parse_string("IX-1p-pl-2 WORK LANDSCAPE fs-LANDSCAPING IX-1p 5xx", parse_all=True).as_list()

print(trial)
print(trial2)

In [None]:
# tokenize based on predefined grammar rules

def custom_tokenize(text):
    """Split a gloss string into tokens with ``full_grammar``.

    The whole string must be consumed by the grammar (``parse_all=True``).

    Args:
        text: the gloss string to tokenize.

    Returns:
        The parsed tokens as a list. If parsing fails, the pyparsing error is
        printed and an empty list is returned instead of raising.
    """
    try:
        parsed = full_grammar.parse_string(text, parse_all=True)
    except pp.ParseException as pe:
        print(f"Failed to parse: {pe}")
        return []
    return parsed.as_list()

In [None]:
# testing custom_tokenize
#
# Same two sample strings as the direct grammar test, run through the wrapper.

sample_glosses = (
    "SCL:1xx",
    "IX-1p-pl-2 WORK LANDSCAPE fs-LANDSCAPING IX-1p 5xx",
)
trial, trial2 = map(custom_tokenize, sample_glosses)

print(trial)
print(trial2)

In [None]:
# model parameters / hyperparameters

BATCH_SIZE = 64
EPOCHS = 30  # This should be at least 10 for convergence
MAX_SEQUENCE_LENGTH = 50  # packed sequence length / positional embedding size
ENG_VOCAB_SIZE = 5056
ASL_VOCAB_SIZE = 2283
num_samples = 1400  # upper bound on the number of sentence pairs read

EMBED_DIM = 128
INTERMEDIATE_DIM = 1024
NUM_HEADS = 8
# Tab-separated file with one "english<TAB>asl gloss" pair per line.
# Adjust the path if sent_pairs.txt lives somewhere else.
data_path = "sent_pairs.txt"

In [None]:
# build (in memory, nothing is written to disk):
    # 1) eng-asl sentence pairs
    # 2) sorted list of unique english vocab (whitespace-separated words)
    # 3) sorted list of unique asl vocab (tokens returned by custom_tokenize)
text_pairs = []

input_tokens = set()
target_tokens = set()

with open(data_path, "r", encoding="utf-8") as f:
    # Skip blank lines so a trailing newline (or the lack of one) does not
    # produce an unparsable or a silently dropped last record.
    lines = [line for line in f.read().split("\n") if line.strip()]

tokens_used = []  # NOTE(review): not used in this cell; kept in case other cells read it
for line in lines[:num_samples]:
    # Each record is "english<TAB>asl"; any other field count raises ValueError.
    input_text, target_text = line.split("\t")
    # One [english, asl] list per record; later cells index pairs with [0] / [1].
    text_pairs.append([input_text, target_text])
    # Collect words, not characters (iterating a str yields characters).
    input_tokens.update(input_text.split())


for pair in text_pairs:
    # Tokenize the whole gloss sentence and add each token individually:
    # custom_tokenize returns a list, which cannot itself be stored in a set.
    target_tokens.update(custom_tokenize(pair[1]))


input_tokens = sorted(input_tokens)
target_tokens = sorted(target_tokens)

print("input_tokens:", input_tokens)
print("output_tokens", target_tokens)
num_encoder_tokens = len(input_tokens)
num_decoder_tokens = len(target_tokens)
print("num_eng_tokens", num_encoder_tokens)
print("num_asl_tokens", num_decoder_tokens)

In [None]:
# glimpse pairs
#
# Print five randomly chosen pairs (drawn with replacement, so repeats are possible).

for sampled_pair in (random.choice(text_pairs) for _ in range(5)):
    print(sampled_pair)

In [None]:
# split data
#
# Shuffle in place, then cut into train / validation / test. Validation gets
# 15% (rounded down); training is sized as if test were the same size, and
# test takes whatever remains after the validation slice.

random.shuffle(text_pairs)

total_pair_count = len(text_pairs)
num_val_samples = int(0.15 * total_pair_count)
num_train_samples = total_pair_count - 2 * num_val_samples
val_end = num_train_samples + num_val_samples

train_pairs = text_pairs[:num_train_samples]
val_pairs = text_pairs[num_train_samples:val_end]
test_pairs = text_pairs[val_end:]

print(f"{len(text_pairs)} total pairs")
print(f"{len(train_pairs)} training pairs")
print(f"{len(val_pairs)} validation pairs")
print(f"{len(test_pairs)} test pairs")

In [None]:
# define train_word_piece

def train_word_piece(text_samples, vocab_size, reserved_tokens):
    """Learn a WordPiece vocabulary from a collection of sentences.

    Args:
        text_samples: the sentences to learn from; they are wrapped in a
            tf.data dataset with `from_tensor_slices`.
        vocab_size: passed to keras_hub as `vocabulary_size`.
        reserved_tokens: special tokens passed to keras_hub as
            `reserved_tokens`.

    Returns:
        Whatever `keras_hub.tokenizers.compute_word_piece_vocabulary` returns
        for these inputs (the learned vocabulary).
    """
    batched_samples = (
        tf_data.Dataset.from_tensor_slices(text_samples).batch(1000).prefetch(2)
    )
    return keras_hub.tokenizers.compute_word_piece_vocabulary(
        batched_samples,
        vocabulary_size=vocab_size,
        reserved_tokens=reserved_tokens,
    )

In [None]:
# Special tokens reserved in both vocabularies.
reserved_tokens = ["[PAD]", "[UNK]", "[START]", "[END]"]

# Both vocabularies are learned from the training split only.
eng_samples = [pair[0] for pair in train_pairs]
asl_samples = [pair[1] for pair in train_pairs]

eng_vocab = train_word_piece(eng_samples, ENG_VOCAB_SIZE, reserved_tokens)
asl_vocab = train_word_piece(asl_samples, ASL_VOCAB_SIZE, reserved_tokens)

In [None]:
# Show the same ten-entry slice of each learned vocabulary.
vocab_preview = slice(100, 110)
print("English Tokens: ", eng_vocab[vocab_preview])
print("ASL Tokens: ", asl_vocab[vocab_preview])

In [None]:
# WordPiece tokenizers built from the learned vocabularies; case is preserved.
eng_tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    vocabulary=eng_vocab, lowercase=False
)
# Do not pass `split=False` here. That option tells WordPieceTokenizer its
# input has already been split into whole words, but this notebook feeds the
# tokenizer raw gloss sentences, and `asl_vocab` was learned with the default
# splitting. With `split=False` a whole sentence would be treated as a single
# word.
asl_tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    vocabulary=asl_vocab, lowercase=False
)

In [None]:
# Round-trip one English sentence through its tokenizer.
eng_input_ex = text_pairs[0][0]
eng_tokens_ex = eng_tokenizer.tokenize(eng_input_ex)
print("English sentence: ", eng_input_ex)
print("Tokens: ", eng_tokens_ex)
eng_recovered_ex = eng_tokenizer.detokenize(eng_tokens_ex)
print("Recovered text after detokenizing: ", eng_recovered_ex)

print()

# Same round trip for the matching ASL gloss sentence.
asl_input_ex = text_pairs[0][1]
asl_tokens_ex = asl_tokenizer.tokenize(asl_input_ex)
print("ASL Gloss sentence: ", asl_input_ex)
print("Tokens: ", asl_tokens_ex)
asl_recovered_ex = asl_tokenizer.detokenize(asl_tokens_ex)
print("Recovered text after detokenizing: ", asl_recovered_ex)

In [None]:
def preprocess_batch(eng, asl):
    """Turn a batch of raw sentence pairs into model inputs and targets.

    Args:
        eng: batch of English sentences (strings).
        asl: batch of ASL gloss sentences (strings).

    Returns:
        A tuple `(inputs, targets)` where `inputs` is a dict with keys
        `"encoder_inputs"` (packed English token ids) and `"decoder_inputs"`
        (packed ASL ids without the last position), and `targets` is the
        packed ASL ids without the first position, i.e. the decoder inputs
        shifted left by one.
    """
    eng = eng_tokenizer(eng)
    asl = asl_tokenizer(asl)

    # Pack `eng` to `MAX_SEQUENCE_LENGTH`; no start/end tokens are configured.
    eng_start_end_packer = keras_hub.layers.StartEndPacker(
        sequence_length=MAX_SEQUENCE_LENGTH,
        pad_value=eng_tokenizer.token_to_id("[PAD]"),
    )
    eng = eng_start_end_packer(eng)

    # Add special tokens (`"[START]"` and `"[END]"`) to `asl` and pad it as well.
    # One extra position is packed so that both one-step-shifted slices
    # returned below have `MAX_SEQUENCE_LENGTH` columns.
    asl_start_end_packer = keras_hub.layers.StartEndPacker(
        sequence_length=MAX_SEQUENCE_LENGTH + 1,
        start_value=asl_tokenizer.token_to_id("[START]"),
        end_value=asl_tokenizer.token_to_id("[END]"),
        pad_value=asl_tokenizer.token_to_id("[PAD]"),
    )
    asl = asl_start_end_packer(asl)

    return (
        {
            "encoder_inputs": eng,
            "decoder_inputs": asl[:, :-1],
        },
        asl[:, 1:],
    )


def make_dataset(pairs):
    """Build a batched, tokenized tf.data pipeline from sentence pairs.

    Args:
        pairs: sequence of `(english, asl)` pairs; must not be empty, because
            the pairs are unzipped into two parallel lists.

    Returns:
        A dataset yielding `preprocess_batch` outputs for batches of
        `BATCH_SIZE` pairs.
    """
    eng_texts, asl_texts = zip(*pairs)
    eng_texts = list(eng_texts)
    asl_texts = list(asl_texts)
    dataset = tf_data.Dataset.from_tensor_slices((eng_texts, asl_texts))
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.map(preprocess_batch, num_parallel_calls=tf_data.AUTOTUNE)
    # Cache the tokenized batches first and shuffle afterwards: a cache placed
    # after shuffle() replays the same order on every pass. Prefetch goes last
    # so it overlaps with the consumer.
    return dataset.cache().shuffle(2048).prefetch(16)


# Training and validation pipelines (built in that order).
train_ds, val_ds = (make_dataset(split) for split in (train_pairs, val_pairs))

In [None]:
# Sanity check: print the tensor shapes of the first training batch.
for first_batch in train_ds.take(1):
    inputs, targets = first_batch
    print(f'inputs["encoder_inputs"].shape: {inputs["encoder_inputs"].shape}')
    print(f'inputs["decoder_inputs"].shape: {inputs["decoder_inputs"].shape}')
    print(f"targets.shape: {targets.shape}")

In [None]:
# Encoder: token + position embedding followed by one transformer encoder block.
encoder_inputs = keras.Input(shape=(None,), name="encoder_inputs")

encoder_embedding = keras_hub.layers.TokenAndPositionEmbedding(
    vocabulary_size=ENG_VOCAB_SIZE,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
)
x = encoder_embedding(encoder_inputs)

encoder_block = keras_hub.layers.TransformerEncoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)
encoder_outputs = encoder_block(inputs=x)
encoder = keras.Model(inputs=encoder_inputs, outputs=encoder_outputs)


# Decoder: defined as its own model that takes the target tokens plus an
# already-encoded source sequence of width EMBED_DIM.
decoder_inputs = keras.Input(shape=(None,), name="decoder_inputs")
encoded_seq_inputs = keras.Input(shape=(None, EMBED_DIM), name="decoder_state_inputs")

decoder_embedding = keras_hub.layers.TokenAndPositionEmbedding(
    vocabulary_size=ASL_VOCAB_SIZE,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
)
x = decoder_embedding(decoder_inputs)

decoder_block = keras_hub.layers.TransformerDecoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)
x = decoder_block(decoder_sequence=x, encoder_sequence=encoded_seq_inputs)
x = keras.layers.Dropout(0.5)(x)

# Softmax head: the model outputs per-token probabilities over the ASL vocabulary.
output_head = keras.layers.Dense(ASL_VOCAB_SIZE, activation="softmax")
decoder_outputs = output_head(x)
decoder = keras.Model(
    inputs=[decoder_inputs, encoded_seq_inputs],
    outputs=decoder_outputs,
)

# Wire the decoder to the encoder output to obtain the end-to-end model.
decoder_outputs = decoder([decoder_inputs, encoder_outputs])

transformer = keras.Model(
    inputs=[encoder_inputs, decoder_inputs],
    outputs=decoder_outputs,
    name="transformer",
)

In [None]:
# Show the architecture, then compile and train.
transformer.summary()
transformer.compile(
    optimizer="rmsprop",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
transformer.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)

In [None]:
def decode_sequences(input_sentences):
    """Greedily decode ASL gloss text for English input.

    Args:
        input_sentences: list of English sentences. The padding and prompt
            shapes below are hard-coded for a batch of one, so pass a
            single-element list.

    Returns:
        The result of detokenizing the generated ASL token ids with
        `asl_tokenizer`; special tokens are not stripped here.
    """
    batch_size = 1

    # Tokenize the encoder input and right-pad it to MAX_SEQUENCE_LENGTH.
    # NOTE(review): inputs longer than MAX_SEQUENCE_LENGTH are passed through
    # unchanged (no truncation) — confirm this cannot happen for the test data.
    encoder_input_tokens = ops.convert_to_tensor(eng_tokenizer(input_sentences))
    if len(encoder_input_tokens[0]) < MAX_SEQUENCE_LENGTH:
        pads = ops.full(
            (1, MAX_SEQUENCE_LENGTH - len(encoder_input_tokens[0])),
            # Same pad id that preprocess_batch uses for the encoder input.
            eng_tokenizer.token_to_id("[PAD]"),
        )
        encoder_input_tokens = ops.concatenate(
            [encoder_input_tokens.to_tensor(), pads], 1
        )

    # Callback for the sampler: scores for the token at position `index`, read
    # from the model output at position `index - 1`. Named so that it does not
    # shadow the builtin `next`.
    def next_token_scores(prompt, cache, index):
        # The model ends in a softmax layer, so these are probabilities rather
        # than raw logits.
        logits = transformer([encoder_input_tokens, prompt])[:, index - 1, :]
        # Ignore hidden states for now; only needed for contrastive search.
        hidden_states = None
        return logits, hidden_states, cache

    # Build a prompt of length 40 with a start token and padding tokens.
    length = 40
    start = ops.full((batch_size, 1), asl_tokenizer.token_to_id("[START]"))
    pad = ops.full((batch_size, length - 1), asl_tokenizer.token_to_id("[PAD]"))
    prompt = ops.concatenate((start, pad), axis=-1)

    generated_tokens = keras_hub.samplers.GreedySampler()(
        next_token_scores,
        prompt,
        stop_token_ids=[asl_tokenizer.token_to_id("[END]")],
        index=1,  # Start sampling after start token.
    )
    generated_sentences = asl_tokenizer.detokenize(generated_tokens)
    return generated_sentences


# Translate two randomly chosen test sentences and print them with the
# special tokens removed.
test_eng_texts = [pair[0] for pair in test_pairs]
for i in range(2):
    input_sentence = random.choice(test_eng_texts)
    translated = decode_sequences([input_sentence]).numpy()[0].decode("utf-8")
    for special_token in ("[PAD]", "[START]", "[END]"):
        translated = translated.replace(special_token, "")
    translated = translated.strip()

    print(f"** Example {i} **")
    print(input_sentence)
    print(translated)
    print()

In [None]:
# Score the first 30 test pairs with ROUGE-1 and ROUGE-2.
rouge_1 = keras_hub.metrics.RougeN(order=1)
rouge_2 = keras_hub.metrics.RougeN(order=2)

for test_pair in test_pairs[:30]:
    input_sentence, reference_sentence = test_pair[0], test_pair[1]

    decoded_batch = decode_sequences([input_sentence])
    translated_sentence = decoded_batch.numpy()[0].decode("utf-8")
    # Remove the special tokens before comparing with the reference.
    for special_token in ("[PAD]", "[START]", "[END]"):
        translated_sentence = translated_sentence.replace(special_token, "")
    translated_sentence = translated_sentence.strip()

    # Update both metrics with (reference, prediction).
    rouge_1(reference_sentence, translated_sentence)
    rouge_2(reference_sentence, translated_sentence)

print("ROUGE-1 Score: ", rouge_1.result())
print("ROUGE-2 Score: ", rouge_2.result())