In [None]:
import json
import pandas as pd
import re
import numpy as np
from tqdm.auto import tqdm
import tensorflow as tf
import matplotlib.pyplot as plt
import json
import cv2
from tensorflow.keras.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
)

In [None]:
# Step 1: Load the raw GQA question file: a dict {key: record-dict}.
with open('GQA/questions1.2/val_balanced_questions.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# Pull only the three fields we need straight from the records, in file order.
# This avoids materialising a (fields x questions) frame of *every* field and
# transposing it, which is slow and memory-hungry for large question files.
df = pd.DataFrame(
    [(rec.get('question'), rec.get('fullAnswer'), rec.get('imageId')) for rec in data.values()],
    columns=['input_text', 'target_text', 'img_path'],
)
df['img_path'] = 'GQA/images/' + df['img_path'] + '.jpg'
del data  # the raw JSON dict is no longer needed
df.head()

# Text cleaning

In [None]:
# 1) Text cleaning ------------------------------------------------------
def clean_text(text: str, lowercase: bool = True) -> str:
    """
    Normalise a raw question/answer string before tokenisation.

    Every run of characters outside ``[a-zA-Z0-9.,!?]`` collapses to a single
    space. That includes whitespace, newlines, apostrophes, quotes, hyphens,
    non-ASCII letters and HTML angle brackets. The result is then stripped and,
    by default, lower-cased.

    Note: this does NOT remove HTML tags. Only the brackets are dropped, and
    the tag names survive as words (``"<b>x</b>"`` -> ``"b x b"``).

    Args:
        text: raw string.
        lowercase: lower-case the result (default True, the historical behaviour).

    Returns:
        The cleaned string (``""`` for empty input).
    """
    # One regex pass already merges adjacent disallowed characters, so no
    # separate newline / multi-space normalisation is needed afterwards.
    cleaned = re.sub(r'[^a-zA-Z0-9.,!?]+', ' ', text).strip()
    return cleaned.lower() if lowercase else cleaned

# Clean answers and questions element-wise (answers first, as before).
df = df.assign(
    target_text=df['target_text'].map(clean_text),
    input_text=df['input_text'].map(clean_text),
)

In [None]:
# Word count of each (question, answer) pair. Both are counted together
# because they end up in the same training sequence.
lens = (
    df['target_text'].str.split().str.len()
    + df['input_text'].str.split().str.len()
).to_numpy()


# Percentile helper, also used by the next cell (MAX_LEN = int(pct(99))).
# `values` is bound to this cell's `lens` when the function is defined, so a
# later reassignment of the global name `lens` cannot silently change MAX_LEN.
def pct(x, values=lens):
    """Return the x-th percentile of the per-pair word counts."""
    return np.percentile(values, x)


print(f"Total examples    : {len(lens):,}")
print(f"Min / Max words   : {lens.min()} / {lens.max()}")
print(f"Mean ± std        : {lens.mean():.1f} ± {lens.std():.1f}")
print("--- Percentiles (word count per raw text pair) ---")
for p in [50, 90, 95, 98, 99]:
    print(f"{p:>3}% : {pct(p):.0f} words")

In [None]:
# Word-count cut-off: keep pairs strictly shorter than the 99th percentile.
MAX_LEN = int(pct(99))

text_pairs = []
n_skipped = 0
for img, question, answer in zip(df.img_path, df.input_text, df.target_text):
    # The only expected failure is a non-string cell (e.g. NaN from a missing
    # field). Skip it explicitly instead of swallowing every possible error.
    if not (isinstance(question, str) and isinstance(answer, str)):
        n_skipped += 1
        continue
    if len(question.split(" ")) + len(answer.split(" ")) < MAX_LEN:
        text_pairs.append((img, question, answer))

if n_skipped:
    print(f"Skipped {n_skipped:,} rows with a non-string question/answer")

text_pairs[:5]

In [None]:
# Number of (image path, question, answer) triples kept by the length filter above.
len(text_pairs)

# Pre-trained tokenizer and text embeddings (ALBERT)

In [None]:
from transformers import AlbertTokenizer, TFAlbertModel

# Tokenizer shipped with the albert-base-v2 checkpoint; used for every encode/decode below.
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')

# The full pretrained ALBERT model is loaded only to borrow its input-embedding
# layer. That layer (token_model) becomes the text embedder of the Decoder.
albert = TFAlbertModel.from_pretrained("albert-base-v2")
token_model = albert.get_input_embeddings()   # embedding sub-layer of the pretrained model
token_model.trainable = False  # keep the pretrained embeddings fixed during training

# Image encoder (MobileNetV3-Small)

In [None]:
from tensorflow.keras.applications import MobileNetV3Small
import tensorflow as tf

# Desired network input size
IMG_SHAPE = (224, 224, 3)

def build_image_encoder(
    weights_path="weights_mobilenet_v3_small_224_1.0_float_no_top_v2.h5",
    input_shape=None,
):
    """
    Frozen MobileNetV3-Small feature extractor that emits a single image "token".

    The backbone's final feature map (B, H, W, C) is flattened into one vector
    of size H*W*C, so the model output is (B, 1, H*W*C).

    Input contract: float32 images of shape ``input_shape``. Keras'
    MobileNetV3 applications include their own input-rescaling layer by
    default, so pixels are expected in the raw [0, 255] range.

    Args:
        weights_path: local "no-top" weights file loaded into the backbone
            (presumably the ImageNet weights, judging by the file name).
        input_shape: (H, W, 3). Defaults to the global ``IMG_SHAPE``, read at call time.

    Returns:
        A frozen ``tf.keras.Model`` named "img_encoder".
    """
    if input_shape is None:
        input_shape = IMG_SHAPE

    base = MobileNetV3Small(input_shape=input_shape,
                            weights=None,
                            include_top=False,
                            pooling=None)
    base.load_weights(weights_path)
    base.trainable = False                    # keep it frozen

    img_input = tf.keras.Input(shape=input_shape, dtype=tf.float32, name="image")
    # training=False keeps the backbone's BatchNorm layers in inference mode,
    # even when the surrounding model is trained.
    feat_map = base(img_input, training=False)               # (B, H, W, C)

    # Static H, W, C of the feature map (fully known for a fixed input_shape)
    _, h, w, c = feat_map.shape
    tokens = tf.keras.layers.Reshape((1, h * w * c))(feat_map)   # (B, 1, H*W*C)

    return tf.keras.Model(img_input, tokens, name="img_encoder")

model_img = build_image_encoder()

def decode_and_resize(path):
    """
    Read a JPEG file and return a float32 tensor of shape IMG_SHAPE.

    Pixels are left in [0, 255] on purpose. Keras' MobileNetV3 models include
    their own Rescaling preprocessing layer (``include_preprocessing=True`` by
    default) and expect un-normalised input. Dividing by 255 here would hand
    the backbone [0, 1] values, which its internal rescaling then squashes
    into a tiny range near -1.

    Args:
        path: scalar tf.string tensor (or Python str) with the image path.
    """
    img = tf.io.read_file(path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, IMG_SHAPE[:2])
    return tf.cast(img, tf.float32)

# Sanity check on two *training* images (the same files the dataset uses).
# Later cells read v3_output.shape[1] as the number of image tokens, so this
# cell must keep defining v3_output.
paths = tf.constant([p for p, _, _ in text_pairs[:2]])
images = tf.map_fn(decode_and_resize, paths, fn_output_signature=tf.float32)
v3_output = model_img.predict(images)
v3_output.shape

In [None]:
# Ids of the special tokens used to lay out each training sequence.
#   '<pad>'  -> pad_id        padding (masked out of attention and loss)
#   '<unk>'  -> unk_id
#   '[CLS]'  -> bos_id        start of the text segment
#   '[MASK]' -> eos_id        repurposed as the end-of-answer marker
#   '[SEP]'  -> sep_id        separates question from answer
#   '$'      -> img_token_id  image placeholder slots
pad_id, unk_id, bos_id, eos_id, sep_id, img_token_id = tokenizer.convert_tokens_to_ids(
    ['<pad>', '<unk>', '[CLS]', '[MASK]', '[SEP]', '$']
)

def custom_tokenize(input_a, input_b=None, num_img_tokens=None):
    """
    Build the (unpadded) token-id sequence used for training and prompting.

    Layout::

        ['$'] * num_img_tokens + ['[CLS]'] + tokens(input_a) + ['[SEP]']
        + tokens(input_b) + ['[MASK]']          # only when input_b is truthy

    The '$' entries are placeholder slots that the Decoder replaces with
    projected image features. '[MASK]' serves as the end-of-answer marker.

    Args:
        input_a: question text.
        input_b: optional answer text. None or "" adds no answer segment.
        num_img_tokens: number of '$' placeholders. Defaults to
            ``v3_output.shape[1]``, the image-encoder sequence length measured
            in the image-encoder cell.

    Returns:
        list[int] of token ids.
    """
    if num_img_tokens is None:
        num_img_tokens = v3_output.shape[1]

    tokens = ['$'] * num_img_tokens + ['[CLS]'] + tokenizer.tokenize(input_a) + ['[SEP]']
    if input_b:
        tokens += tokenizer.tokenize(input_b) + ['[MASK]']

    return tokenizer.convert_tokens_to_ids(tokens)

# Quick check of the layout: question only, then question + answer.
tokens_single, tokens_pair = custom_tokenize("A man"), custom_tokenize("A man", "on moon")
print(tokens_single)
print(tokens_pair)

In [None]:
# Full vocabulary: token piece -> id
vocab_dict = tokenizer.get_vocab()

# Sorted by id for readability. Only the first entries are displayed:
# echoing the whole vocabulary would flood the notebook output.
sorted_vocab = dict(sorted(vocab_dict.items(), key=lambda item: item[1]))
dict(list(sorted_vocab.items())[:20])

In [None]:
# Token length of every (question, answer) pair under custom_tokenize.
# Only the lengths are kept; the token lists themselves are discarded at once.
lengths = [len(custom_tokenize(question, answer)) for _, question, answer in text_pairs]

percentiles = [80, 90, 95, 99]
results = {p: np.percentile(lengths, p) for p in percentiles}
print(results)

# Tokenization, encoding and dataset preparation

In [None]:
BATCH_SIZE = 32
VOCAB_SIZE = tokenizer.vocab_size
# Sequence length in *tokens* (image placeholders included) covering 99% of
# pairs. This rebinds the word-count MAX_LEN used earlier for filtering.
MAX_LEN = int(np.percentile(lengths, 99))

def encode_pair(input_a, input_b=None, max_len=None):
    """
    Tokenise (question, answer) and pad/truncate to a fixed length.

    The token layout is exactly that of ``custom_tokenize``; that helper is
    reused instead of repeating its logic here.

    Args:
        input_a: question text.
        input_b: optional answer text (falsy -> no answer segment).
        max_len: output length. Defaults to the global ``MAX_LEN``, read at call time.

    Returns:
        np.ndarray, dtype int32, shape (max_len,). Right-padded with ``pad_id``,
        or truncated from the right when the sequence is longer.
    """
    if max_len is None:
        max_len = MAX_LEN

    token_ids = custom_tokenize(input_a, input_b)[:max_len]
    token_ids = token_ids + [pad_id] * (max_len - len(token_ids))
    return np.array(token_ids, dtype=np.int32)

def encode_example(img_path, text: str, summary: str):
    """
    Turn one (image, question, answer) triple into training arrays.

    Args:
        img_path: image file path, passed through unchanged.
        text: question text.
        summary: answer text.

    Returns:
        (img_path, ids, labels, loss_mask)
          ids       : padded input ids from ``encode_pair``.
          labels    : next-token targets, ``labels[t] == ids[t + 1]``. This is
                      ``ids`` shifted LEFT by one, with the last slot set to ``pad_id``.
          loss_mask : float32, 1.0 exactly where the target is an answer token
                      or the end-of-answer marker, i.e. positions after the one
                      whose target is '[SEP]', excluding PAD targets.
    """
    ids = encode_pair(text, summary)
    labels = np.concatenate([ids[1:], [pad_id]])   # next-token targets

    # Position whose *target* is [SEP]. Every later position predicts the answer.
    sep_idxs = np.where(labels == sep_id)[0]
    # No [SEP] (e.g. truncated away): put sep_pos past the end, giving an all-zero mask.
    sep_pos = int(sep_idxs[0]) if sep_idxs.size else len(ids)

    positions = np.arange(len(labels))
    loss_mask = (positions > sep_pos).astype(np.float32) * (labels != pad_id).astype(np.float32)

    return img_path, ids, labels, loss_mask

In [None]:
img_path, text, summary = 'flickr/flickr30k_images/1000092795.jpg', 'hello', 'good morning'

# Encode one toy example and show the three aligned arrays.
img_path, input_ids, label_ids, loss_mask = encode_example(img_path, text, summary)

print("  Input IDs :", input_ids)
print("\n Label IDs :", label_ids)
print("\n Loss Mask :", loss_mask)

In [None]:
# ── 0. Tokenise every example exactly once (plain Python) ──────────────
encoded = [encode_example(img, t, s) for (img, t, s) in text_pairs]

# Transpose the list of 4-tuples into four stacked tensors:
#   img_path -> [N] strings;  ids / labels / masks -> [N, MAX_LEN]
img_path, ids, labels, masks = (tf.constant(np.stack(column, 0)) for column in zip(*encoded))

# ── 1. Build the pure-TF dataset ───────────────────────────────────────
ds = (
    tf.data.Dataset.from_tensor_slices({'img_path':img_path,"input_ids": ids, "labels": labels, "loss_mask": masks})
    .shuffle(len(text_pairs))
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Drop the Python-side references; the dataset keeps what it needs.
del encoded, img_path, ids, labels, masks

In [None]:
vocab = tokenizer.get_vocab()           # token piece -> id

# Reverse lookup (id -> token piece) for decoding ids back to text.
id_to_token = dict(zip(vocab.values(), vocab.keys()))

# 2) Decoder: turn token ids back into readable text (PADs dropped)
def decode_token_ids(token_ids: list[int]) -> str:
    """
    Convert ALBERT token ids into a readable string for inspection.

    - PAD ids and the '$' image-placeholder token are dropped.
    - Word-start pieces carry the SentencePiece marker '▁' (U+2581). Pieces
      are concatenated and each marker becomes a space, so sub-word pieces
      re-join into whole words. (GPT-2's 'Ġ' marker never occurs in this vocab.)
    - Bracketed special tokens such as '[CLS]', '[SEP]', '[MASK]', '<unk>'
      are kept as standalone words.
    - Ids missing from ``id_to_token`` are rendered as '?'.

    Args:
        token_ids: iterable of token ids (Python or numpy ints).

    Returns:
        Decoded text with whitespace runs collapsed to single spaces.
    """
    pieces = []
    for tid in token_ids:
        tid = int(tid)
        if tid == pad_id:
            continue
        tok = id_to_token.get(tid, '?')
        if tok == '$':
            continue  # image placeholder, not text
        if len(tok) > 2 and tok[0] + tok[-1] in ('[]', '<>'):
            pieces.append(f' {tok} ')              # special token -> its own word
        else:
            pieces.append(tok.replace('\u2581', ' '))
    return ' '.join(''.join(pieces).split())

# 3) Inspect a few batches from the TF dataset
for batch in ds.take(10):
    path_batch = batch['img_path'].numpy()  # array of bytes paths
    input_ids = batch['input_ids'].numpy()  # shape (batch, MAX_LEN)
    labels    = batch['labels'].numpy()

    # Distinct names for the batch array and the per-row path (they used to share `pth`).
    for i, (pth, ids_row, lbl_row) in enumerate(zip(path_batch, input_ids, labels), start=1):
        print(f"\n🟢 Sample {i}")
        print("  pth: ", pth.decode('utf-8'))
        print("  Input IDs: ", ids_row.tolist())
        print("  Decoded:   ", decode_token_ids(ids_row.tolist()))
        print("  Label IDs: ", lbl_row.tolist())

In [None]:
# Grab one batch and check that the loss mask covers only the answer span.
batch = next(iter(ds.take(1)))
input_ids = batch["input_ids"]
labels    = batch["labels"]
loss_mask = batch["loss_mask"]

# Sample to inspect; clamped so a batch smaller than 20 cannot raise.
j = min(19, int(tf.shape(input_ids)[0]) - 1)
print(loss_mask[j])
print('\n',decode_token_ids(batch["input_ids"][j].numpy()))

# Llama-style decoder architecture

In [None]:
def apply_rope(x, sin, cos):
    """
    Rotary position embedding: rotate each interleaved feature pair
    (x[..., 2i], x[..., 2i+1]) by a position-dependent angle.

    x   : (B, h, T, d) with even d
    sin : (T, d//2), broadcast over batch and heads
    cos : (T, d//2)
    Returns a tensor shaped like x.
    """
    evens, odds = x[..., ::2], x[..., 1::2]        # each (B, h, T, d/2)

    # 2-D rotation of every (even, odd) pair: [cos -sin; sin cos] @ [e; o]
    rotated = tf.stack(
        [evens * cos - odds * sin,
         evens * sin + odds * cos],
        axis=-1,
    )                                               # (B, h, T, d/2, 2)

    # Interleave the pairs back into the last axis: (..., d/2, 2) -> (..., d)
    return tf.reshape(rotated, tf.shape(x))

def make_sincos(seq_len, dim, base=10000):
    '''
    Sine / cosine tables for RoPE.

    angle[t, i] = t / base ** (2i / dim)  for t in [0, seq_len), 2i in [0, dim)
    Returns (sin, cos), each of shape (seq_len, dim//2) for an even dim.
    '''
    positions = tf.cast(tf.range(seq_len), tf.float32)              # (T,)
    exponents = tf.cast(tf.range(0, dim, 2), tf.float32) / dim      # (d/2,)
    denominators = base ** exponents[None, :]                       # (1, d/2)
    angles = positions[:, None] / denominators                      # (T, d/2)
    return tf.sin(angles), tf.cos(angles)

class MultiHeadAttention(tf.keras.layers.Layer):
    """
    Multi-head scaled-dot-product attention, implemented from scratch, with
    rotary position embeddings (RoPE) applied to queries and keys.

    Args
    ----
    d_model     : int   – total embedding size (must be divisible by num_heads)
    num_heads   : int   – number of attention heads
    dropout     : float – dropout on attention weights (0.0 = no dropout)

    Call Signature
    --------------
    output = mha(
        query,                     # (B, T_q, d_model)
        value=None,                # (B, T_v, d_model)  – defaults to query
        key=None,                  # (B, T_k, d_model)  – defaults to value
        mask=None,                 # (B, 1, T_q, T_k), (B, 1, 1, T_k) or (B, T_q, T_k);
                                   # 1 = blocked, 0 = visible
        use_causal_mask=False,     # True → autoregressive causal mask
        training=None
    )
    Returns only the attended output, shape (B, T_q, d_model).
    """
    def __init__(self, d_model, num_heads, dropout=0.0, **kwargs):
        super().__init__(**kwargs)
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model={d_model} must be divisible by num_heads={num_heads}"
            )

        self.d_model   = d_model
        self.num_heads = num_heads
        self.depth     = d_model // num_heads

        # Linear projections for Q, K, V and final output (no bias)
        self.wq   = tf.keras.layers.Dense(d_model, use_bias=False)
        self.wk   = tf.keras.layers.Dense(d_model, use_bias=False)
        self.wv   = tf.keras.layers.Dense(d_model, use_bias=False)
        self.wo   = tf.keras.layers.Dense(d_model, use_bias=False)

        self.dropout = tf.keras.layers.Dropout(dropout)

    # ────────────────────────────────────────────────────────────────────────
    # Helpers
    # ────────────────────────────────────────────────────────────────────────
    def _split_heads(self, x, B):
        """
        Reshape (B, T, d_model) → (B, num_heads, T, depth)
        so attention runs on every head in parallel.
        """
        x = tf.reshape(x, (B, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    @staticmethod
    def _scaled_dot_product_attention(q, k, v, mask, dropout, training=None):
        """
        Core attention: softmax(QKᵀ / √d_k + mask · (-1e9)) V

        ``mask`` (1 = blocked) must broadcast against the (B, h, T_q, T_k) scores.
        Returns only the attended values, shape (B, h, T_q, depth).
        """
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scores = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(dk)  # (B,h,T_q,T_k)

        if mask is not None:
            scores += (mask * -1e9)  # large negative → ~zero probability

        attn = tf.nn.softmax(scores, axis=-1)
        attn = dropout(attn, training=training)
        return tf.matmul(attn, v)  # (B,h,T_q,depth)

    # ────────────────────────────────────────────────────────────────────────
    # Forward pass
    # ────────────────────────────────────────────────────────────────────────
    def call(
        self,
        query,
        value=None,
        key=None,
        mask=None,
        use_causal_mask=False,
        training=None
    ):
        if value is None:
            value = query
        if key is None:
            key = value

        B = tf.shape(query)[0]
        Tq = tf.shape(query)[1]          # sequence length of Q
        Tk = tf.shape(key)[1]            # sequence length of K

        # 1. Linear projections
        q = self.wq(query)   # (B, T_q, d_model)
        k = self.wk(key)     # (B, T_k, d_model)
        v = self.wv(value)   # (B, T_v, d_model)

        # 2. Reshape for multi-head
        q = self._split_heads(q, B)  # (B, h, T_q, depth)
        k = self._split_heads(k, B)  # (B, h, T_k, depth)
        v = self._split_heads(v, B)  # (B, h, T_v, depth)

        # 3. Rotary position embeddings on Q and K, so their dot product
        #    depends on relative position as well as content.
        max_len = tf.maximum(Tq, Tk)
        sin, cos = make_sincos(max_len, self.depth)       # depth = d_model / num_heads
        q = apply_rope(q, sin[:Tq], cos[:Tq])
        k = apply_rope(k, sin[:Tk], cos[:Tk])

        # 4. Masks (1 = blocked). Cast so boolean / int masks also work, and add
        #    a head axis to (B, T_q, T_k) masks so they broadcast over heads
        #    rather than being misaligned with the batch axis.
        if mask is not None:
            mask = tf.cast(mask, tf.float32)
            if mask.shape.rank == 3:
                mask = mask[:, tf.newaxis, :, :]
        if use_causal_mask:
            # strictly-upper-triangular ones: position t cannot see t' > t
            causal = 1.0 - tf.linalg.band_part(tf.ones((Tq, Tk)), -1, 0)
            causal = causal[tf.newaxis, tf.newaxis, :, :]  # (1,1,T_q,T_k)
            mask = causal if mask is None else tf.maximum(mask, causal)

        # 5. Scaled dot-product attention
        attn_out = self._scaled_dot_product_attention(
            q, k, v, mask, self.dropout, training=training,
        )  # (B,h,T_q,depth)

        # 6. Concatenate heads
        attn_out = tf.transpose(attn_out, perm=[0, 2, 1, 3])  # (B,T_q,h,depth)
        attn_out = tf.reshape(attn_out, (B, -1, self.d_model))  # (B,T_q,d_model)

        # 7. Final linear layer
        return self.wo(attn_out)  # (B,T_q,d_model)
    
class RMSNorm(tf.keras.layers.Layer):
    """
    Root-mean-square normalisation over the last axis with a learnable gain:

        y = x / sqrt(mean(x², axis=-1) + epsilon) * scale
    """
    def __init__(self, hidden_size, epsilon=1e-8, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.epsilon = epsilon

        # γ, one gain per feature, initialised to 1
        self.scale = self.add_weight(
            name="scale",
            shape=(hidden_size,),
            initializer="ones",
            trainable=True,
        )

    def call(self, x):
        mean_square = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        root = tf.sqrt(mean_square + self.epsilon)
        return (x / root) * self.scale

class CausalSelfAttention(tf.keras.layers.Layer):
    """Pre-norm causal self-attention sub-block: returns x + MHA(RMSNorm(x))."""

    def __init__(self, d_model, num_heads, dropout=0.0):
        super().__init__()
        self.mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout=dropout)
        self.rmsnorm = RMSNorm(d_model)
        self.add = tf.keras.layers.Add()

    def call(self, x, padding_mask=None, training=None):
        normed = self.rmsnorm(x)

        # The same normalised tensor serves as query, key and value.
        attended = self.mha(
            query=normed,
            value=normed,
            key=normed,
            mask=padding_mask,          # may be None
            use_causal_mask=True,
            training=training,
        )

        # Residual connection on the un-normalised input.
        return self.add([x, attended])

In [None]:
class SwiGLU(tf.keras.layers.Layer):
    """
    Gated feed-forward unit: W2 · (a ⊙ SiLU(b)), where [a, b] = split(W1 · x).

    W1 maps hidden_dim → factor * hidden_dim, split into two equal halves.
    W2 maps the gated half back to hidden_dim. Neither projection has a bias.
    """
    def __init__(self, hidden_dim, factor=4):
        super().__init__()
        self.lin1 = tf.keras.layers.Dense(factor * hidden_dim, use_bias=False)   # W1
        self.lin2 = tf.keras.layers.Dense(hidden_dim, use_bias=False)            # W2

    def call(self, x):
        projected = self.lin1(x)                            # (..., factor*d)
        value, gate = tf.split(projected, 2, axis=-1)       # two (..., factor*d/2)
        silu_gate = gate * tf.sigmoid(gate)                 # SiLU(b)
        return self.lin2(value * silu_gate)

class FeedForward(tf.keras.layers.Layer):
    """Pre-norm SwiGLU feed-forward sub-block with dropout: x + Dropout(SwiGLU(RMSNorm(x)))."""

    def __init__(self, d_model, dropout_rate=0.1):
        super().__init__()
        self.seq = tf.keras.Sequential([
            SwiGLU(d_model),
            tf.keras.layers.Dropout(dropout_rate),
        ])
        self.rmsnorm = RMSNorm(d_model)

    def call(self, x, training=None):
        normed = self.rmsnorm(x)
        update = self.seq(normed, training=training)
        return x + update    # residual on the raw input
    
class DecoderLayer(tf.keras.layers.Layer):
    """
    One decoder block: causal self-attention followed by a SwiGLU feed-forward
    sub-block. Each is pre-normalised and wrapped in a residual connection.

    Args:
        d_model: model width.
        num_heads: attention heads.
        dropout_rate: used by BOTH the attention weights and the feed-forward output.
    """
    def __init__(self, *, d_model, num_heads, dropout_rate=0.1):
        super().__init__()
        self.causal_self_attention = CausalSelfAttention(num_heads=num_heads, d_model=d_model, dropout=dropout_rate)
        # Forward dropout_rate explicitly; otherwise the FFN silently keeps its own default.
        self.ffn = FeedForward(d_model, dropout_rate=dropout_rate)

    def call(self, x, padding_mask=None, training=None):
        x = self.causal_self_attention(x, padding_mask=padding_mask, training=training)
        x = self.ffn(x, training=training)
        return x

In [None]:
class Decoder(tf.keras.layers.Layer):
    """
    Image-conditioned decoder stack.

    Call with ``inputs = (images, token_ids)``:
      * ``images``    – batch accepted by ``build_image_encoder`` (B, H, W, 3)
      * ``token_ids`` – (B, T) ints whose first N entries are image placeholders

    The image encoder output is projected to d_model (N tokens). Those N
    tokens replace the embeddings of the first N placeholder positions. The
    result, shape (B, T, d_model), then runs through ``num_layers`` decoder
    blocks, with PAD positions masked out of attention.

    Text embeddings are concatenated without a projection, so ``d_model`` must
    equal the width of ``token_model``'s embeddings.
    """
    def __init__(
        self, *, num_layers, d_model, num_heads, dropout_rate=0.1, pad_token_id=pad_id):
        super(Decoder, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers
        self.pad_token_id = pad_token_id

        # Frozen pretrained ALBERT input-embedding layer (token ids -> vectors)
        self.token_embedding = token_model
        # Frozen MobileNetV3-Small feature extractor: (B, H, W, 3) -> (B, N, F)
        self.img_embedding = build_image_encoder()

        self.img_projection = tf.keras.layers.Dense(d_model)   # F -> d_model

        self.dropout = tf.keras.layers.Dropout(dropout_rate)

        self.dec_layers = [DecoderLayer(d_model=d_model, num_heads=num_heads, dropout_rate=dropout_rate)
                           for _ in range(num_layers)]

        self.last_attn_scores = None   # never populated; kept for compatibility

    def call(self, inputs, training=None):
        images, token_ids = inputs[0], inputs[1]

        # 1.0 at PAD positions, shaped (B, 1, 1, T) to broadcast over heads and queries
        pad_mask = tf.cast(tf.equal(token_ids, self.pad_token_id), tf.float32)[:, tf.newaxis, tf.newaxis, :]

        img = self.img_embedding(images, training=False)   # backbone always in inference mode
        img = self.img_projection(img)                     # (B, N, d_model)

        text = self.token_embedding(token_ids)             # (B, T, embedding width)

        # Make sure img has the same dtype as text
        img = tf.cast(img, dtype=text.dtype)

        # Replace the first N placeholder embeddings with the N image tokens.
        # N is read from the image tensor itself, not from a global computed
        # in another cell.
        num_img_tokens = img.shape[1]
        if num_img_tokens is None:
            num_img_tokens = tf.shape(img)[1]
        x = tf.concat([img, text[:, num_img_tokens:, :]], axis=1)
        x = self.dropout(x, training=training)

        for layer in self.dec_layers:
            x = layer(x, padding_mask=pad_mask, training=training)

        return x
    
class Transformer(tf.keras.Model):
    """
    Image-conditioned decoder-only language model.

    ``inputs = (images, token_ids)`` → logits of shape (batch, seq_len, input_vocab_size).
    Only the logits are returned (no attention weights).
    """
    def __init__(self, *, num_layers, d_model, num_heads, input_vocab_size, dropout_rate=0.1):
        super().__init__()

        self.decoder = Decoder(
            num_layers=num_layers,
            d_model=d_model,
            num_heads=num_heads,
            dropout_rate=dropout_rate,
        )

        self.rmsnorm = RMSNorm(d_model)
        self.final_layer = tf.keras.layers.Dense(input_vocab_size)

    def call(self, inputs, training=False):
        # Keras .fit() passes every model input in the first argument, so
        # `inputs` is the (images, token_ids) pair itself.
        hidden = self.decoder(inputs, training=training)     # (batch, seq_len, d_model)
        logits = self.final_layer(self.rmsnorm(hidden))      # (batch, seq_len, vocab)

        # Drop any Keras mask so it does not rescale losses/metrics (b/250038731).
        try:
            del logits._keras_mask
        except AttributeError:
            pass

        return logits

In [None]:
num_layers = 2
d_model = 128   # must equal the ALBERT embedding width: text embeddings are used without projection
num_heads = 8
dropout_rate = 0.1
EPOCHS = 3

model = Transformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    input_vocab_size=VOCAB_SIZE,
    dropout_rate=dropout_rate,   # use the config value above
)

# One dummy forward pass creates all weights. A zero image has the right
# shape and avoids depending on a particular image file being on disk.
dummy_images = tf.zeros((1,) + IMG_SHAPE, dtype=tf.float32)
dummy_tokens = tf.zeros((1, MAX_LEN), dtype=tf.int32)
_ = model((dummy_images, dummy_tokens))
model.summary()

In [None]:
WARM_FRAC   = 0.1        # fraction of all training steps spent warming up
DECAY_RATE  = 4          # exponential decay rate after warm-up
LR_FLOOR    = 1e-6       # the schedule never returns less than this
LR_PEAK_DESIRED = 8e-4   # learning rate at the end of warm-up (choose 8e-4 or 9e-4)

# `ds` is batched without drop_remainder, so each epoch runs ceil(N / BATCH_SIZE) steps.
steps_per_epoch = -(-len(text_pairs) // BATCH_SIZE)
num_steps     = EPOCHS * steps_per_epoch
# At least one warm-up step: 0 would make current_peak infinite and LR_SCALE 0.
warmup_steps  = max(1, int(num_steps * WARM_FRAC))
# Un-scaled schedule value at the end of warm-up: d_model^-0.5 * warmup^-0.5
current_peak  = 1.0 / tf.sqrt(tf.cast(d_model * warmup_steps, tf.float32))
LR_SCALE      = LR_PEAK_DESIRED / current_peak.numpy()

@tf.keras.utils.register_keras_serializable()   # so it can round-trip in SavedModel/.keras files
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """
    Noam-style warm-up followed by exponential decay, rescaled and floored.

        warm(s)  = d_model^-0.5 * min(s^-0.5, s * warmup_steps^-1.5)
        decay(s) = exp(-decay_rate * max(s - warmup_steps, 0) / total_steps)
        lr(s)    = max(warm(s) * decay(s) * lr_scale, lr_floor)

    Args:
        d_model: model width in the Noam scaling term.
        total_steps: total optimiser steps; also the decay horizon.
        warmup_frac: fraction of ``total_steps`` spent warming up (at least 1 step).
        decay_rate: exponential decay rate after warm-up.
        lr_scale: multiplier applied to the raw schedule.
        lr_floor: minimum learning rate. It is stored in the config so a
            reloaded schedule does not depend on a global ``LR_FLOOR``.
    """
    def __init__(self, d_model, total_steps=num_steps,
                 warmup_frac=WARM_FRAC, decay_rate=DECAY_RATE,
                 lr_scale=LR_SCALE, lr_floor=LR_FLOOR):
        super().__init__()
        # Constructor arguments, returned verbatim by get_config(). Re-deriving
        # warmup_frac from the int-truncated warm-up step count is lossy, and a
        # reloaded schedule could end up one warm-up step short.
        self._init_args = {
            "d_model": d_model,
            "total_steps": total_steps,
            "warmup_frac": warmup_frac,
            "decay_rate": decay_rate,
            "lr_scale": lr_scale,
            "lr_floor": lr_floor,
        }
        self.d_model      = tf.cast(d_model, tf.float32)
        # With 0 warm-up steps, `step * warmup_steps**-1.5` evaluates 0 * inf at
        # step 0, which can yield a NaN learning rate; require at least one step.
        self.warmup_steps = tf.cast(max(1, int(total_steps * warmup_frac)), tf.float32)
        self.decay_rate   = decay_rate
        self.decay_steps  = tf.cast(total_steps, tf.float32)
        self.lr_scale     = tf.cast(lr_scale, tf.float32)
        self.lr_floor     = lr_floor

    def __call__(self, step):
        step  = tf.cast(step, tf.float32)
        arg1  = tf.math.rsqrt(step)
        arg2  = step * tf.math.pow(self.warmup_steps, -1.5)
        warm  = tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
        decay = tf.math.exp(-self.decay_rate *
                            tf.maximum(step - self.warmup_steps, 0.) /
                            self.decay_steps)
        lr = warm * decay * self.lr_scale
        return tf.maximum(lr, self.lr_floor)

    def get_config(self):
        """Plain-Python, JSON-serialisable constructor arguments."""
        args = self._init_args
        return {
            "d_model":      int(args["d_model"]),
            "total_steps":  int(args["total_steps"]),
            "warmup_frac":  float(args["warmup_frac"]),
            "decay_rate":   args["decay_rate"],
            "lr_scale":     float(args["lr_scale"]),
            "lr_floor":     float(args["lr_floor"]),
        }

    @classmethod
    def from_config(cls, config):
        # Older configs without "lr_floor" still load (the default applies).
        return cls(**config)
    
learning_rate = CustomSchedule(d_model)

optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9, clipnorm=1.0)

# Plot the schedule the optimizer will actually use. The schedule is a pure
# function of the step, so evaluating it here has no side effects. It is
# evaluated once and reused for the plot and the summary line.
lr_curve = learning_rate(tf.range(num_steps, dtype=tf.float32))

fig, ax = plt.subplots()
ax.plot(lr_curve.numpy())
ax.set(title="Learning-rate schedule", xlabel="Train Step", ylabel="Learning Rate")
plt.show()

if num_steps > 0:
    print(f"steps: {num_steps:,}   peak LR: {float(tf.reduce_max(lr_curve)):.3e}   "
          f"final LR: {float(lr_curve[-1]):.3e}")

In [None]:
# Decode images inside the input pipeline. `ds` is already batched, so
# map_fn decodes each batch of paths one image at a time.
ds_for_fit = ds.map(
    lambda b: (
        (tf.map_fn(decode_and_resize, b['img_path'], fn_output_signature=tf.float32), b["input_ids"]),
        b["labels"],
        b["loss_mask"],   # 3rd element: Keras uses it as per-token sample weights
    ),
    num_parallel_calls=tf.data.AUTOTUNE
)

del ds

# Compile.
# - The loss keeps the default (scalar) reduction. With reduction="none" Keras
#   receives an unreduced (batch, seq) loss tensor, and the gradient becomes a
#   *sum* over every token, scaling with batch x sequence size. loss_mask is
#   still applied, as the sample weight.
# - weighted_metrics applies loss_mask to the accuracy too. A plain `metrics=`
#   entry would also score PAD and question positions, which the model is
#   never trained to predict.
model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    weighted_metrics=["sparse_categorical_accuracy"],
)

# Build callbacks. NOTE: with EPOCHS = 3 and patience=3, EarlyStopping cannot
# stop training early.
callbacks = [
    EarlyStopping(monitor="loss", patience=3,
                  restore_best_weights=True, verbose=1),
    ModelCheckpoint(
        filepath="best_summary.keras",        # or "best_summary.h5"
        monitor="loss",
        save_best_only=True,
        verbose=1             # full model (weights + optimizer + LR schedule)
    )
]

# Fit! Keras prints epoch/step progress by default.
history = model.fit(
    ds_for_fit,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1    # 1 = progress bar, loss & acc per epoch
)

In [None]:
def encode(text: str, max_len: int = MAX_LEN) -> tf.Tensor:
    """
    Build the generation prompt for ``text`` as a 1-D int32 tensor, unpadded.

    The layout is the question half of ``custom_tokenize``: image placeholders
    ('$'), '[CLS]', the question pieces, '[SEP]'. The model is then asked to
    continue with the answer.

    ``max_len`` is accepted for backward compatibility but NOT applied. The
    prompt is never truncated or padded here.
    """
    ids = custom_tokenize(text)                    # Python list of ids, no padding
    return tf.constant(ids, tf.int32)              # 1-D tensor


# ── Autoregressive answer generation (temperature sampling) ───────────
def generate_answer(image_path: str,
                    question:   str,
                    max_new_tokens: int = 30,
                    temperature: float = 0.5) -> str:
    """
    Generate an answer to ``question`` about the image at ``image_path``.

    Decoding samples from the last-position logits divided by ``temperature``
    (``tf.random.categorical``). ``temperature <= 0`` switches to greedy argmax.
    Generation stops at the end-of-answer token (``eos_id``), after
    ``max_new_tokens`` tokens, or once the sequence reaches ``MAX_LEN``.

    Side effect: the image is displayed with matplotlib.

    Returns:
        The decoded answer only. Prompt tokens and special tokens are excluded.
    """
    image = tf.map_fn(decode_and_resize, tf.constant([image_path]),
                      fn_output_signature=tf.float32)          # (1, H, W, 3)

    text_tokens = encode(question)[tf.newaxis, :]                # (1, prompt_len)
    prompt_len = text_tokens.shape[1]

    for _ in range(max_new_tokens):
        # Direct call instead of model.predict(): predict() rebuilds a data
        # pipeline on every call, which dominates a single decoding step.
        logits = model((image, text_tokens), training=False)[:, -1, :]   # (1, vocab)
        if temperature > 0:
            next_id = tf.random.categorical(logits / temperature, 1, dtype=tf.int32)
        else:
            next_id = tf.argmax(logits, axis=-1, output_type=tf.int32)[:, tf.newaxis]

        text_tokens = tf.concat([text_tokens, next_id], 1)

        # next_id is (1, 1): index down to a scalar before int()
        if int(next_id[0, 0]) == eos_id or text_tokens.shape[1] >= MAX_LEN:
            break

    answer_ids = text_tokens[0, prompt_len:].numpy().tolist()

    imgbgr = cv2.imread(image_path)
    imgrgb = cv2.cvtColor(imgbgr, cv2.COLOR_BGR2RGB)
    plt.imshow(imgrgb)
    plt.show()

    return tokenizer.decode(answer_ids, skip_special_tokens=True)

In [None]:
img, ques = "vqa/rename/13291.jpg", "what are they getting ready to do?"
print(generate_answer(img, ques))

In [None]:
# Generic open-ended prompts for probing the model on arbitrary images.
question_pool = [
    "What is present in the image?",
    "How would you describe the background setting?",
    "What is the central subject or point of focus?",
    "What is happening in the image?",
    "Where might this photo have been taken?",
    "Describe the image.",
    "Are there people in the picture?",
    "Is there an animal in the picture?",
    "What objects are visible in the image?",
    "What activity or event is occurring in the image?",
    "Can you describe the colors present in the image?",
    "What do you think the mood or atmosphere of the image is?",
    "What draws your attention most in the image?",
    "Does the image appear to be candid or posed?",
    "Is there any text or signage visible in the image?",
    "Are there any buildings or structures visible?",
    "Does the image convey any emotions or feelings?",
    "Are there any artistic or stylistic elements in the photo?",
    "If you could give this image a title, what would it be?",
    "What is unique about the image?",
    "What can you tell?",
]


In [None]:
img, ques = "vqa/rename/74.jpg", "where is the dog ?"
print(generate_answer(img, ques))

In [None]:
img, ques = "vqa/rename/7519.jpg", "what is the color of the shirt?"
print(generate_answer(img, ques))

In [None]:
# Answer the first 20 questions of the cleaned GQA frame.
# (`df2` is not defined anywhere above; `df` holds the cleaned questions
# with their image paths.)
for j in range(20):
    print('\n', j)
    img = df['img_path'].iloc[j]
    ques = df['input_text'].iloc[j]
    print(generate_answer(img, ques))