In [1]:
import keras_hub
import random
import requests

import keras
from keras import ops

import tensorflow.data as tf_data
import tensorflow as tf
from tensorflow_text.tools.wordpiece_vocab import (
    bert_vocab_from_dataset,
)
import pandas as pd
from keras_nlp.samplers import TopKSampler
from sklearn.metrics import f1_score


import numpy as np
import os
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# imports for pre-parsing
from pyparsing import Word, alphas as pp_alpha, nums as pp_nums
import pyparsing as pp
pp.ParserElement.enablePackrat()
import re

In [3]:
class DictTokenizer:
    def __init__(self, vocab, tokenizer_fn):
        self.token_to_id_map = vocab
        self.id_to_token_map = {i: t for t, i in vocab.items()}
        self.tokenizer_fn = tokenizer_fn

    def __call__(self, text_batch):
        return [
            [self.token_to_id_map.get(tok, self.token_to_id_map.get("[UNK]", 0)) 
             for tok in self.tokenizer_fn(text)]
            for text in text_batch
        ]

    def tokenize(self, text):
        return [self.token_to_id_map.get(tok, self.token_to_id_map.get("[UNK]", 0)) 
                for tok in self.tokenizer_fn(text)]

    def detokenize(self, token_ids):
        if isinstance(token_ids, tf.Tensor):
            token_ids = token_ids.numpy()
        elif isinstance(token_ids, tf.RaggedTensor):
            token_ids = token_ids.to_tensor().numpy()
        elif isinstance(token_ids, int):
            token_ids = [token_ids]

        return " ".join([self.id_to_token_map.get(int(tok_id), "[UNK]") for tok_id in token_ids])

    def token_to_id(self, token):
        return self.token_to_id_map.get(token, self.token_to_id_map.get("[UNK]", 0))


In [4]:
# regex rules

alpha_regexp = r"""
(?!((?:THUMB-)?(?:IX|POSS|SELF)))   # negative lookahead for blocked glosses
[A-Z]                               # must start with uppercase
(?:                                 # optional middle section
    (?:                             # non-capturing group for allowed connectors
        (?:[-/][A-Z])               # hyphen or slash must be followed by uppercase
      | (?:_[0-9])                  # underscore must be followed by digit
      | (?:\+(?:[A-Z#]|fs-))       # plus + (uppercase OR # OR the literal fs-)
      | [A-Z0-9]                    # regular letter/digit continuation
    )
)*                                  # repeatable
(?:\.)?                             # optional trailing period
"""

In [5]:
# conventions kept for parsing

cl_prefix = pp.one_of(["CL", "DCL", "LCL", "SCL", "BCL", "BPCL", "PCL", "ICL"])
fs_prefix = pp.Literal("fs-")
index_core_ix = pp.Literal("IX")
other_index_core = pp.one_of(["POSS", "SELF"])
hashtag = pp.Literal("#")
dash = pp.Literal("-")
contraction = pp.Literal("^")
period = pp.Literal(".")
alpha = pp.Word(pp_alpha, max=1)
num = pp.Word(pp_nums, max=1)
word = pp.Regex(alpha_regexp, flags=re.X)

In [6]:
# grammar rules

full_grammar = pp.OneOrMore(
    fs_prefix |               # fingerspelling fs
    word |
    cl_prefix |               # classifiers like CL, DCL, etc.
    index_core_ix |           # IX
    other_index_core |        # POSS, SELF
    hashtag |                 # #
    contraction |             # ^
    period |                  # .
    dash |
    num |
    alpha                     # fallback LAST
)

In [7]:
# tokenize based on predefined grammar rules

def custom_asl_tokenize(text):
    try:
        if "'" in text:
            text = text.replace("'", "")
        if "++" in text:
            text = text.replace("++", "+")
        return full_grammar.parse_string(text, parse_all=True).asList()
    except pp.ParseException as pe:
        print(text)
        print(f"Failed to parse: {pe}")
        return []

In [8]:
def custom_eng_tokenize(text):
    # Perserve punctuation and digits
    text = re.sub(r'([^\w\s]|\d)', r' \1 ', text)
    # Convert to lowercase
    text = text.lower()
    # Split on whitespace
    tokens = text.split()
    return tokens

In [9]:
# generate
    # 1) list of eng-asl sentence pairs
    # 2) set of unique english vocab
    # 3) set of unique asl vocab
data_path = "/Users/adrianajimenez/Desktop/Downloads/REUAICT/Real-Code/2025-ASL-data/sent_pairs_joined.txt"
    
text_pairs = []
eng_texts = []
asl_texts = []
SPECIAL_TOKENS = ["[PAD]", "[START]", "[END]", "[UNK]"]
eng_tokens = set(SPECIAL_TOKENS)
asl_tokens = set(SPECIAL_TOKENS)
max_length = 0

with open(data_path, "r", encoding="utf-8") as f:
    lines = f.read().split("\n")

for line in lines:
    pair = []
    eng_text, asl_text = line.split("\t")
    eng_texts.append(eng_text)
    asl_texts.append(asl_text)
    pair.append(eng_text.lower())
    pair.append(asl_text)
    text_pairs.append(pair)
    
for text in eng_texts:
    tokens = custom_eng_tokenize(text)
    length = len(tokens)
    if length > max_length:
        max_length = length
    for token in tokens:
        if token not in eng_tokens:
                eng_tokens.add(token)
            
for text in asl_texts:
    tokens = custom_asl_tokenize(text)
    length = len(tokens)
    if length > max_length:
        max_length = length
    for token in tokens:
        if token not in asl_tokens:
                asl_tokens.add(token)
                            
max_encoder_seq_length = max([len(txt) for txt in eng_texts])
max_decoder_seq_length = max([len(txt) for txt in asl_texts])

eng_tokens = sorted(list(eng_tokens))
asl_tokens = sorted(list(asl_tokens))

print("eng_tokens:", eng_tokens)
print("asl_tokens", asl_tokens)
num_encoder_tokens = len(eng_tokens)
num_decoder_tokens = len(asl_tokens)
print("num_eng_tokens", num_encoder_tokens)
print("num_asl_tokens", num_decoder_tokens)

asl_tokens ['#', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'A-LEVEL-ABOVE', 'A-LEVEL-BELOW', 'A-LOT', 'A-OK', 'A-WAYS', 'AA', 'AAA', 'ABANDON', 'ABC', 'ABOUT', 'ABOVE', 'ABOVE_3', 'ABUSE', 'AC', 'ACCEPT', 'ACCIDENT', 'ACCOMMODATE', 'ACROSS', 'ACT', 'ACT+AGENT', 'ACTION', 'ADD-TO', 'ADDICTED', 'ADMIT', 'ADULT-TALL', 'ADVANTAGE', 'ADVENTURE', 'ADVISE', 'ADVISE/INFLUENCE', 'ADVISE/INFLUENCE+AGENT', 'ADVISER', 'AFRAID', 'AFTER', 'AFTERNOON_3', 'AGAIN', 'AGAINST', 'AGE', 'AGE-SIX+HALF', 'AGE-THIRTEEN', 'AGE-TWENTY-ONE', 'AGE-TWENTY-ONE_2', 'AGENT', 'AGREE', 'AIR', 'AIRPLANE', 'ALARM', 'ALCOHOL', 'ALEC-BALDWIN', 'ALI', 'ALL', 'ALL-DAY', 'ALL-GONE', 'ALL-NIGHT', 'ALL-NIGHT_3', 'ALL-THE-WAY', 'ALL-YEARS-HS', 'ALLERGY', 'ALLOW', 'ALL_2', 'ALMOST', 'ALONE', 'ALRIGHT', 'ALSO', 'ALWAYS', 'AMERICAN-AIRLINES', 'AMONG', 'AMONG_2', 'AMY', 'ANALYZE', 'ANALYZE_2', 'ANALYZE_3', 'AND', 'ANGELA', 'ANGRY', 'ANIMAL', 'ANKLE', 'ANN', 'ANNOUNCE', 'ANSWER', 'ANSWER+AGENT', 'ANY', 'ANY+MOR

In [10]:
main_asl_glosses = set()
split_pattern = r"[\/\+\-]"
        
for token in asl_tokens:
    parts = re.split(split_pattern, token)
    if all(part.isalpha() and part.isupper() for part in parts):
        main_asl_glosses.add(token)

In [11]:
# model parameters / hyperparameters

BATCH_SIZE = 16
EPOCHS = 20
EMBED_DIM = 64
INTERMEDIATE_DIM = 128
NUM_HEADS = 4
MAX_SEQUENCE_LENGTH = max_length

In [12]:
# glimpse pairs

for _ in range(5):
    print(random.choice(text_pairs))

["that kind of sound has a different meaning. what happens in situations with grants is that there can be many different cultural situations where there are uncontrollable voices, voice excess or grunts, and it's not acceptable. it really blocks your communication process.", 'THAT KIND NOISE HAVE DIFFERENT MEAN AND HAPPEN SITUATION fs-GRUNT MANY IX CULTURE SITUATION WHERE fs-UN MANAGE/CONTROL VOICE VOICE fs-EXCESS fs-GRUNT AMONG_2 NOT ACCEPT NOT AND REALLY PREVENT POSS PROGRESS 1 BUOY']
["i like it, if it has butter, i'm fine with that.", 'IX LIKE IX #IF IX HAVE BUTTER IN FINE IX']
["i didn't realize that nebraska was really that big! but anyway, so i was driving...", 'IX REALIZE fs-NEB REALLY THAT BIG BUT  IX DRIVE']
['john should give (his) father a car.', 'fs-JOHN SHOULD GIFT FATHER CAR']
['mother should not buy that car.', 'IX MOTHER IX CAR BUY SHOULD NOT']


In [13]:
# split data

random.shuffle(text_pairs)
num_val_samples = int(0.15 * len(text_pairs))
num_train_samples = len(text_pairs) - 2 * num_val_samples
train_pairs = text_pairs[:num_train_samples]
val_pairs = text_pairs[num_train_samples : num_train_samples + num_val_samples]
test_pairs = text_pairs[num_train_samples + num_val_samples :]

print(f"{len(text_pairs)} total pairs")
print(f"{len(train_pairs)} training pairs")
print(f"{len(val_pairs)} validation pairs")
print(f"{len(test_pairs)} test pairs")

3389 total pairs
2373 training pairs
508 validation pairs
508 test pairs


In [14]:

eng_vocab = dict([(char, i) for i, char in enumerate(eng_tokens)])
asl_vocab = dict([(char, i) for i, char in enumerate(asl_tokens)])

eng_tokenizer = DictTokenizer(eng_vocab, tokenizer_fn=custom_eng_tokenize)
asl_tokenizer = DictTokenizer(asl_vocab, tokenizer_fn=custom_asl_tokenize)

print(eng_tokenizer)
print(asl_tokenizer)

print(eng_vocab)
print(asl_vocab)

<__main__.DictTokenizer object at 0x103d12590>
<__main__.DictTokenizer object at 0x103d12290>
{'#': 0, '-': 1, '.': 2, '0': 3, '1': 4, '2': 5, '3': 6, '4': 7, '5': 8, '6': 9, '7': 10, '8': 11, '9': 12, 'A': 13, 'A-LEVEL-ABOVE': 14, 'A-LEVEL-BELOW': 15, 'A-LOT': 16, 'A-OK': 17, 'A-WAYS': 18, 'AA': 19, 'AAA': 20, 'ABANDON': 21, 'ABC': 22, 'ABOUT': 23, 'ABOVE': 24, 'ABOVE_3': 25, 'ABUSE': 26, 'AC': 27, 'ACCEPT': 28, 'ACCIDENT': 29, 'ACCOMMODATE': 30, 'ACROSS': 31, 'ACT': 32, 'ACT+AGENT': 33, 'ACTION': 34, 'ADD-TO': 35, 'ADDICTED': 36, 'ADMIT': 37, 'ADULT-TALL': 38, 'ADVANTAGE': 39, 'ADVENTURE': 40, 'ADVISE': 41, 'ADVISE/INFLUENCE': 42, 'ADVISE/INFLUENCE+AGENT': 43, 'ADVISER': 44, 'AFRAID': 45, 'AFTER': 46, 'AFTERNOON_3': 47, 'AGAIN': 48, 'AGAINST': 49, 'AGE': 50, 'AGE-SIX+HALF': 51, 'AGE-THIRTEEN': 52, 'AGE-TWENTY-ONE': 53, 'AGE-TWENTY-ONE_2': 54, 'AGENT': 55, 'AGREE': 56, 'AIR': 57, 'AIRPLANE': 58, 'ALARM': 59, 'ALCOHOL': 60, 'ALEC-BALDWIN': 61, 'ALI': 62, 'ALL': 63, 'ALL-DAY': 64, 'ALL-

In [15]:
eng_input_ex = text_pairs[0][0]
eng_tokens_ex = eng_tokenizer.tokenize(eng_input_ex)
print("English sentence: ", eng_input_ex)
print("Tokens: ", eng_tokens_ex)
print(
    "Recovered text after detokenizing: ",
    eng_tokenizer.detokenize(eng_tokens_ex),
)

print()

asl_input_ex = text_pairs[0][1]
asl_tokens_ex = asl_tokenizer.tokenize(asl_input_ex)
print("ASL sentence: ", asl_input_ex)
print("Tokens: ", asl_tokens_ex)
print(
    "Recovered text after detokenizing: ",
    asl_tokenizer.detokenize(asl_tokens_ex),
)

English sentence:  my parents gave birth to dana.
Tokens:  [1762, 1926, 1110, 291, 2775, 691, 9]
Recovered text after detokenizing:  my parents gave birth to dana .

ASL sentence:  REALLY POSS MOTHER+FATHER IX BORN fs-DANA
Tokens:  [1507, 1435, 1178, 950, 228, 2117, 452]
Recovered text after detokenizing:  REALLY POSS MOTHER+FATHER IX BORN fs- DANA


In [16]:
def preprocess_batch(eng, asl):
    eng_start_end_packer = keras_hub.layers.StartEndPacker(
        sequence_length=MAX_SEQUENCE_LENGTH,
        pad_value=eng_tokenizer.token_to_id("[PAD]"),
        dtype="int32"
    )
    eng = eng_start_end_packer(eng)

    asl_start_end_packer = keras_hub.layers.StartEndPacker(
        sequence_length=MAX_SEQUENCE_LENGTH + 1,
        start_value=asl_tokenizer.token_to_id("[START]"),
        end_value=asl_tokenizer.token_to_id("[END]"),
        pad_value=asl_tokenizer.token_to_id("[PAD]"),
        dtype="int32"
    )
    asl = asl_start_end_packer(asl)

    decoder_inputs = asl[:, :-1]
    decoder_outputs = asl[:, 1:]

    return {
        "encoder_inputs": eng,
        "decoder_inputs": decoder_inputs
    }, decoder_outputs


In [17]:
def make_dataset(pairs):
    
    eng_ids = [eng_tokenizer.tokenize(sent) for sent, _ in pairs]    
    asl_ids = [asl_tokenizer.tokenize(sent) for _, sent in pairs]

    # 🛠️ Force token type to int32
    eng_tensor = tf.ragged.constant(eng_ids, dtype=tf.int32)
    asl_tensor = tf.ragged.constant(asl_ids, dtype=tf.int32)
    
    dataset = tf_data.Dataset.from_tensor_slices((eng_tensor, asl_tensor))
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.map(preprocess_batch, num_parallel_calls=tf_data.AUTOTUNE)
    return dataset.shuffle(2048).prefetch(16).cache()

train_ds = make_dataset(train_pairs)
val_ds = make_dataset(val_pairs)
print(train_ds)

2025-07-18 17:13:50.429421: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2
2025-07-18 17:13:50.429817: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 8.00 GB
2025-07-18 17:13:50.430293: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 2.67 GB
I0000 00:00:1752873230.431253 14830960 pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
I0000 00:00:1752873230.432224 14830960 pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


<CacheDataset element_spec=({'encoder_inputs': TensorSpec(shape=(None, 71), dtype=tf.int32, name=None), 'decoder_inputs': TensorSpec(shape=(None, 71), dtype=tf.int32, name=None)}, TensorSpec(shape=(None, 71), dtype=tf.int32, name=None))>


In [18]:
for inputs, targets in train_ds.take(1):
    print(f'inputs["encoder_inputs"].shape: {inputs["encoder_inputs"].shape}')
    print(f'inputs["decoder_inputs"].shape: {inputs["decoder_inputs"].shape}')
    print(f"targets.shape: {targets.shape}")

inputs["encoder_inputs"].shape: (16, 71)
inputs["decoder_inputs"].shape: (16, 71)
targets.shape: (16, 71)


2025-07-18 17:13:52.304718: W tensorflow/core/kernels/data/cache_dataset_ops.cc:916] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2025-07-18 17:13:52.306604: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [19]:

# Encoder
encoder_inputs = keras.Input(shape=(None,), name="encoder_inputs")

x = keras_hub.layers.TokenAndPositionEmbedding(
    vocabulary_size=num_encoder_tokens,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
)(encoder_inputs)

encoder_outputs = keras_hub.layers.TransformerEncoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)(inputs=x)
encoder = keras.Model(encoder_inputs, encoder_outputs)


# Decoder
decoder_inputs = keras.Input(shape=(None,), name="decoder_inputs")
encoded_seq_inputs = keras.Input(shape=(None, EMBED_DIM), name="decoder_state_inputs")

x = keras_hub.layers.TokenAndPositionEmbedding(
    vocabulary_size=num_decoder_tokens,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
)(decoder_inputs)

x = keras_hub.layers.TransformerDecoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)(decoder_sequence=x, encoder_sequence=encoded_seq_inputs)
x = keras.layers.Dropout(0.5)(x)
decoder_outputs = keras.layers.Dense(num_decoder_tokens, activation="softmax")(x)
decoder = keras.Model(
    [
        decoder_inputs,
        encoded_seq_inputs,
    ],
    decoder_outputs,
)
decoder_outputs = decoder([decoder_inputs, encoder_outputs])

transformer = keras.Model(
    [encoder_inputs, decoder_inputs],
    decoder_outputs,
    name="transformer",
)

In [20]:
optimizer = keras.optimizers.Adam(learning_rate=1e-4)

transformer.summary()
transformer.compile(
    optimizer, loss="sparse_categorical_crossentropy", metrics=["sparse_categorical_accuracy"]
)
transformer.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)

Epoch 1/20


2025-07-18 17:13:57.147536: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


[1m149/149[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 97ms/step - loss: 6.2773 - sparse_categorical_accuracy: 0.6007 - val_loss: 4.3628 - val_sparse_categorical_accuracy: 0.8628
Epoch 2/20
[1m149/149[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 83ms/step - loss: 3.9343 - sparse_categorical_accuracy: 0.8651 - val_loss: 2.6160 - val_sparse_categorical_accuracy: 0.8628
Epoch 3/20
[1m149/149[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 79ms/step - loss: 2.3509 - sparse_categorical_accuracy: 0.8651 - val_loss: 1.5188 - val_sparse_categorical_accuracy: 0.8628
Epoch 4/20
[1m149/149[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 81ms/step - loss: 1.4834 - sparse_categorical_accuracy: 0.8654 - val_loss: 1.0810 - val_sparse_categorical_accuracy: 0.8633
Epoch 5/20
[1m149/149[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 78ms/step - loss: 1.1256 - sparse_categorical_accuracy: 0.8702 - val_loss: 0.9201 - val_sparse_categorical_accuracy: 0.8752
E

<keras.src.callbacks.history.History at 0x16dac60e0>

In [None]:
class CustomSampler1(keras_hub.samplers.Sampler):
    def __init__(self, sentence, vocab,
                 **kwargs):
        super().__init__(**kwargs)
        self.input_words    = [w.upper() for w in sentence if len(w) > 2]
        self.vocab          = vocab
        self.prefered_boost = 2.0
        self.allowed_boost  = 1.5
        self.temperature    = 0.8
        self.top_k          = 10               # store it
        self.max_repeats    = 2
        # precompute once
        self.allowed_ids    = self.compute_allowed_ids("allowed")
        self.prefered_ids   = self.compute_allowed_ids("prefer")
        self.prev_token     = None
        self.repeat_count   = 0

    def compute_allowed_ids(self, preference):
        prefer, allowed = set(), set()
        for word in self.input_words:
            p, s = word[:3], word[-3:]
            for gloss, gid in self.vocab.items():
                if gloss.startswith(p) and gloss.endswith(s):
                    prefer.add(gid)
                elif gloss.startswith(p) or gloss.endswith(s):
                    allowed.add(gid)
        return prefer if preference=="prefer" else allowed

    def get_next_token(self, logits):
        # 1) Boost matching gloss logits
        for gid in self.allowed_ids:
            boosted = logits[0, gid] + self.allowed_boost
            logits  = keras.ops.slice_update(
                logits, (0, gid),
                keras.ops.reshape(boosted, (1,1))
            )
        for gid in self.prefered_ids:
            boosted = logits[0, gid] + self.prefered_boost
            logits  = keras.ops.slice_update(
                logits, (0, gid),
                keras.ops.reshape(boosted, (1,1))
            )

        # 2) Repetition penalty
        if self.prev_token is not None and self.repeat_count >= self.max_repeats:
            neg_inf = tf.constant([[-1e9]], dtype=logits.dtype)
            logits  = keras.ops.slice_update(logits, (0, self.prev_token), neg_inf)

        # 3) Temperature scaling
        scaled_logits = logits / tf.cast(self.temperature, logits.dtype)

        # 4) Top‑k filtering via threshold
        topk_vals, _ = tf.math.top_k(scaled_logits, k=self.top_k)
        threshold    = tf.reduce_min(topk_vals, axis=-1, keepdims=True)  # [batch,1]
        neg_inf      = tf.constant(-1e9, dtype=scaled_logits.dtype)
        filtered     = tf.where(
            scaled_logits < threshold,
            neg_inf,
            scaled_logits
        )

        # 5) Sample from the filtered logits
        # tf.random.categorical expects 2D [batch, vocab_size]
        sample_id = tf.random.categorical(filtered, num_samples=1)  # [batch,1]
        token     = tf.squeeze(sample_id, axis=1)                  # [batch]

        # 6) Update repeat tracking
        tid = int(token.numpy()[0])
        if tid == self.prev_token:
            self.repeat_count += 1
        else:
            self.prev_token   = tid
            self.repeat_count = 1

        return token  # shape [1]


In [23]:
class CustomSampler2(keras_hub.samplers.Sampler):
    def __init__(self, sentence, vocab,
                 key = "70426f32-2291-4d4e-9f6d-a220bc1a93f6",
                 lang = "EN",
                 target_lan = "SIMPLE",
                 senses_url = "https://babelnet.io/v9/getSenses?",
                 **kwargs):
        super().__init__(**kwargs)
        self.input_words    = [w.upper() for w in sentence if len(w) > 2]
        self.vocab          = vocab
        self.prefered_boost = 2.0
        self.allowed_boost  = 1.5
        self.temperature    = 0.8
        self.top_k          = 10               # store it
        self.max_repeats    = 2
        self.url = senses_url
        self.key = key
        self.target = target_lan
        self.lang = lang
        # precompute once
        self.allowed_ids    = self.compute_allowed_ids("allowed")
        self.prefered_ids   = self.compute_allowed_ids("prefer")
        self.prev_token     = None
        self.repeat_count   = 0
        
    def fetch_senses(self, word):
        """Fetch all lemmas for `word` from BabelNet."""
        synset_params = {"lemma": word, "searchLang": self.lang, "targetLang": self.target, "key": self.key}
        try:
            resp = requests.get(self.url, params=synset_params)
            resp.raise_for_status()
            senses = resp.json()
            print(senses)
        except Exception as e:
            print(f"BabelNet fetch error for '{word}':", e)
            raise

        lemmas = set()
        for sense in senses:
            props = sense.get("properties", {})
            lemma = props.get("simpleLemma")
            print(lemma)
            if isinstance(lemma, str) and len(lemma) > 4:
                lemmas.add(lemma.upper())
        return lemmas

    def compute_allowed_ids(self, preference):
        prefer, allowed = set(), set()
        for word in self.input_words:
            senses = self.fetch_senses(word)
            print(senses)
            for each in senses:
                underscores = each.count("_")
                if underscores < 2:
                    print(each)
                    p, s = each[:5], each[-5:]
                    print(p)
                    print(s)
                    for gloss, gid in self.vocab.items():
                        if gloss.startswith(p) and gloss.endswith(s):
                            prefer.add(gid)
                        elif gloss.startswith(p) or gloss.endswith(s):
                            allowed.add(gid)
        print(prefer)
        print(allowed)
        return prefer if preference=="prefer" else allowed

    def get_next_token(self, logits):
        # 1) Boost matching gloss logits
        for gid in self.allowed_ids:
            print(gid)
            boosted = logits[0, gid] + self.allowed_boost
            logits  = keras.ops.slice_update(
                logits, (0, gid),
                keras.ops.reshape(boosted, (1,1))
            )
        for gid in self.prefered_ids:
            print(gid)
            boosted = logits[0, gid] + self.prefered_boost
            logits  = keras.ops.slice_update(
                logits, (0, gid),
                keras.ops.reshape(boosted, (1,1))
            )

        # 2) Repetition penalty
        if self.prev_token is not None and self.repeat_count >= self.max_repeats:
            neg_inf = tf.constant([[-1e9]], dtype=logits.dtype)
            logits  = keras.ops.slice_update(logits, (0, self.prev_token), neg_inf)

        # 3) Temperature scaling
        scaled_logits = logits / tf.cast(self.temperature, logits.dtype)

        # 4) Top‑k filtering via threshold
        topk_vals, _ = tf.math.top_k(scaled_logits, k=self.top_k)
        threshold    = tf.reduce_min(topk_vals, axis=-1, keepdims=True)  # [batch,1]
        neg_inf      = tf.constant(-1e9, dtype=scaled_logits.dtype)
        filtered     = tf.where(
            scaled_logits < threshold,
            neg_inf,
            scaled_logits
        )

        # 5) Sample from the filtered logits
        # tf.random.categorical expects 2D [batch, vocab_size]
        sample_id = tf.random.categorical(filtered, num_samples=1)  # [batch,1]
        token     = tf.squeeze(sample_id, axis=1)                  # [batch]

        # 6) Update repeat tracking
        tid = int(token.numpy()[0])
        if tid == self.prev_token:
            self.repeat_count += 1
        else:
            self.prev_token   = tid
            self.repeat_count = 1

        return token  # shape [1]

In [24]:
def decode_sequences(input_sentences):
    with tf.device('/CPU:0'):
        batch_size     = 1
        prompt_length  = MAX_SEQUENCE_LENGTH   # still pad encoder/prompt to fixed size

        # — Encoder input preparation (unchanged) —
        encoder_inputs = ops.convert_to_tensor(eng_tokenizer(input_sentences))
        seq_len = tf.shape(encoder_inputs)[1]
        if seq_len < prompt_length:
            pad_amt = prompt_length - seq_len
            pads    = ops.full((1, pad_amt), 0, dtype=encoder_inputs.dtype)
            encoder_inputs = ops.concatenate([encoder_inputs, pads], axis=1)

        # — Compute dynamic decode lengths —
        input_ids      = eng_tokenizer(input_sentences)[0]
        input_len      = len(input_ids)
        max_decode_len = min(prompt_length, input_len)               # at most input+2
        min_decode_len = max(1, int(input_len * 0.75))                   # at least 80% of input

        # — Initialize your sampler as before —
        sampler = CustomSampler2(
            sentence=input_sentences[0].split(),
            vocab=asl_tokenizer.token_to_id_map,
        )

        # — Build initial prompt ([START] + PADs) —
        start_id = asl_tokenizer.token_to_id("[START]")
        pad_id   = asl_tokenizer.token_to_id("[PAD]")
        prompt   = ops.full((batch_size, prompt_length), pad_id, dtype=tf.int32)
        start_t  = ops.convert_to_tensor([[start_id]], dtype=tf.int32)
        prompt   = ops.slice_update(prompt, (0, 0), start_t)

        # — Next‑token fn (unchanged) —
        def next_fn(pr, cache, idx):
            logits = transformer([encoder_inputs, pr])[:, idx - 1, :]
            return logits, None, cache

        cache     = None
        generated = []

        # — Decoding loop with dynamic bounds —
        for idx in range(1, max_decode_len):
            logits, _, cache = next_fn(prompt, cache, idx)
            token = sampler.get_next_token(logits)                # shape [1]
            token = ops.cast(token, dtype=prompt.dtype)
            prompt = ops.slice_update(prompt, (0, idx), ops.expand_dims(token, 0))

            tok_id = int(token.numpy()[0])
            generated.append(tok_id)

            # only stop if we’ve hit [END] *and* run at least min_decode_len steps
            if tok_id == asl_tokenizer.token_to_id("[END]") and idx >= min_decode_len:
                break

        return asl_tokenizer.detokenize(generated)


outputs = []
test_eng_texts = [pair[0] for pair in test_pairs]

for i in range(3):
    output_pairs = []
    input_sentence = random.choice(test_eng_texts)
    translated = decode_sequences([input_sentence])
    translated = (
        translated.replace("[PAD]", "")
        .replace("[START]", "")
        .replace("[END]", "")
        .strip()
    )
    output_pairs.append(input_sentence)
    output_pairs.append(translated)
    print(output_pairs)
    outputs.append(output_pairs)

df = pd.DataFrame(outputs, columns=["input sentence", "translation"])
df.to_csv("/Users/adrianajimenez/Desktop/Downloads/REUAICT/Real-Code/2025-ASL-data/seq2seq_code/word_level/babelnet_sampler1.txt", index=False)


[]
set()
[{'type': 'BabelSense', 'properties': {'fullLemma': 'T.H.E.', 'simpleLemma': 't.h.e.', 'lemma': {'lemma': 'T.H.E.', 'type': 'HIGH_QUALITY'}, 'source': 'BABELNET', 'senseKey': 'bn:02738756n', 'frequency': 0, 'language': 'SIMPLE', 'pos': 'NOUN', 'synsetID': {'id': 'bn:02738756n', 'pos': 'NOUN', 'source': 'BABELNET'}, 'pronunciations': {'audios': [], 'transcriptions': []}, 'bKeySense': False, 'idSense': 0, 'tags': {}}}, {'type': 'BabelSense', 'properties': {'fullLemma': 'The_the', 'simpleLemma': 'the_the', 'lemma': {'lemma': 'The_the', 'type': 'HIGH_QUALITY'}, 'source': 'BABELNET', 'senseKey': 'bn:00832768n', 'frequency': 0, 'language': 'SIMPLE', 'pos': 'NOUN', 'synsetID': {'id': 'bn:00832768n', 'pos': 'NOUN', 'source': 'BABELNET'}, 'pronunciations': {'audios': [], 'transcriptions': []}, 'bKeySense': False, 'idSense': 0, 'tags': {}}}, {'type': 'BabelSense', 'properties': {'fullLemma': 'Ҫ', 'simpleLemma': 'ҫ', 'lemma': {'lemma': 'Ҫ', 'type': 'HIGH_QUALITY'}, 'source': 'BABELNET', 