In [1]:
import os
import json
os.environ["KERAS_BACKEND"] = "tensorflow"

In [2]:
import tensorflow as tf
import keras
import keras_hub
from keras._tf_keras.keras import layers
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

2025-08-11 13:14:39.439943: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754898279.458405   12785 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754898279.463734   12785 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1754898279.480297   12785 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754898279.480337   12785 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754898279.480340   12785 computation_placer.cc:177] computation placer alr

In [None]:
nltk.download('punkt')

In [12]:
AUTOTUNE = tf.data.AUTOTUNE
IMG_DIR= os.getcwd() + "/data/images"
CAPTION_FILE = os.getcwd() + "/data/captions.json"
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
EPOCHS = 3
MAX_TOKENS = 5000
SEQ_LEN = 32
VOCAB_SIZE = 10_000

In [4]:
def load_data():
    with open(CAPTION_FILE, 'r') as f:
        captions_data = json.load(f)

    image_id_to_caption = {}
    for data in captions_data:
        img_id = data["image_id"]
        caption = data["caption"]
        if img_id not in image_id_to_caption:
            image_id_to_caption[img_id] = []
        image_id_to_caption[img_id].append(caption)

    image_caption_pairs = []
    for img_id, captions in image_id_to_caption.items():
        filename = f"COCO_train2014_{img_id:012d}.jpg"
        img_path = os.path.join(IMG_DIR, filename)
        for caption in captions:
            image_caption_pairs.append((img_path, f"<START> {caption} <END>"))
    
    return image_caption_pairs

In [5]:
image_caption_pairs = load_data()
print(f"Total pairs: {len(image_caption_pairs)}")

Total pairs: 93950


In [6]:
image_paths, captions = zip(*load_data())
dataset = tf.data.Dataset.from_tensor_slices((list(image_paths), list(captions)))
dataset = dataset.shuffle(len(image_caption_pairs), reshuffle_each_iteration=False)
dataset_size = tf.data.experimental.cardinality(dataset).numpy()
train_size = int(0.8 * dataset_size)
train_ds = dataset.take(train_size)
val_ds = dataset.skip(train_size)

I0000 00:00:1754746749.759671    4727 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 3584 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3060 Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.6


In [7]:
tokenizer = layers.TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_sequence_length=SEQ_LEN,
    standardize="lower_and_strip_punctuation",
    output_mode="int"
)
text_data = train_ds.map(lambda img, cap: cap, num_parallel_calls=tf.data.AUTOTUNE)
tokenizer.adapt(text_data)

2025-08-09 19:27:54.695748: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [8]:
vocab = tokenizer.get_vocabulary()
with open("vocab.txt", "w", encoding="utf-8") as f:
    for token in vocab:
        f.write(token + "\n")

with open("tokenizer_meta.json", "w") as f:
    json.dump({"max_len": 30, "max_tokens": len(vocab)}, f)

In [9]:
def preprocess_image_caption(image_path, caption):
    img = tf.io.read_file(image_path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, (224, 224))
    img = tf.cast(img, tf.float32) / 255.0
    cap_tokens = tokenizer(caption)
    return img, cap_tokens

In [10]:
train_ds = train_ds.map(preprocess_image_caption, num_parallel_calls=tf.data.AUTOTUNE).batch(32).prefetch(tf.data.AUTOTUNE)
val_ds   = val_ds.map(preprocess_image_caption, num_parallel_calls=tf.data.AUTOTUNE).batch(32).prefetch(tf.data.AUTOTUNE)

In [None]:
# Encoder (ViT)
vit_backbone = keras_hub.models.Backbone.from_preset("vit_base_patch16_224_imagenet")
vit_backbone.trainable = False

image_input = keras.Input(shape=(224, 224, 3), name="image")
encoder_outputs = vit_backbone(image_input)
#encoder_features = layers.GlobalAveragePooling1D()(encoder_outputs)
#encoder_features = layers.Dense(512, activation="relu")(encoder_features)

# Decoder (GRU)
caption_input = keras.Input(shape=(None,), dtype=tf.int32, name="caption")
x = layers.Embedding(input_dim=VOCAB_SIZE, output_dim=512, mask_zero=True)(caption_input)
attn_out = layers.Attention()([x, encoder_outputs])
x = layers.Concatenate()([x, attn_out])
x = layers.GRU(512, return_sequences=True)(x)
output = layers.Dense(VOCAB_SIZE, activation="softmax")(x)

model = keras.Model(inputs=[image_input, caption_input], outputs=output)

In [None]:
index_to_word = {i: w for i, w in enumerate(tokenizer.get_vocabulary())}
start_token_id = tokenizer("<START>").numpy()[0]
end_token_id = tokenizer("<END>").numpy()[0]

class BLEUCallback(keras.callbacks.Callback):
    def __init__(self, val_dataset, max_len=30):
        super().__init__()
        self.val_dataset = val_dataset
        self.max_len = max_len

    def decode_image(self, img_tensor):
        """Greedy decoding for image caption."""
        dec_input = tf.expand_dims([start_token_id], 0)
        result = []

        for _ in range(self.max_len):
            preds = self.model([img_tensor, dec_input], training=False)
            pred_id = tf.argmax(preds[:, -1, :], axis=-1).numpy()[0]
            if pred_id == end_token_id:
                break
            result.append(index_to_word.get(pred_id, ""))
            dec_input = tf.concat([dec_input, tf.expand_dims([pred_id], 0)], axis=1)

        return " ".join(result)

    def on_epoch_end(self, epoch, logs=None):
        smoothie = SmoothingFunction().method4
        bleu_scores = []

        for img_batch, cap_batch in self.val_dataset.take(10):  # limit for speed
            for i in range(len(img_batch)):
                img_tensor = tf.expand_dims(img_batch[i], 0)
                pred_caption = self.decode_image(img_tensor)

                gt_tokens = cap_batch[i].numpy()
                gt_words = [index_to_word.get(idx, "") for idx in gt_tokens
                            if idx not in [0, start_token_id, end_token_id]]

                reference = [nltk.word_tokenize(" ".join(gt_words))]
                candidate = nltk.word_tokenize(pred_caption)

                bleu = sentence_bleu(reference, candidate, smoothing_function=smoothie)
                bleu_scores.append(bleu)

        avg_bleu = np.mean(bleu_scores)
        logs["val_bleu"] = avg_bleu
        print(f"\nEpoch {epoch+1} — Val BLEU: {avg_bleu:.4f}")

In [12]:
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"]
)

In [13]:
model.summary()

In [14]:
os.makedirs("checkpoints", exist_ok=True)
checkpoint = keras.callbacks.ModelCheckpoint(
    "checkpoints/best_model.keras",
    save_best_only=True,
    monitor="val_loss",
    mode="min"
)

In [16]:
model.fit(
    train_ds.map(lambda img, cap: ((img, cap[:, :-1]), cap[:, 1:])),
    validation_data=val_ds.map(lambda img, cap: ((img, cap[:, :-1]), cap[:, 1:])),
    epochs=EPOCHS,
    callbacks=[checkpoint]
)

Epoch 1/3


I0000 00:00:1754747885.835610    5470 cuda_dnn.cc:529] Loaded cuDNN version 90300


[1m2349/2349[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1186s[0m 501ms/step - accuracy: 0.1649 - loss: 4.7727 - val_accuracy: 0.1742 - val_loss: 3.0523
Epoch 2/3
[1m2349/2349[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1291s[0m 550ms/step - accuracy: 0.1790 - loss: 2.9458 - val_accuracy: 0.1877 - val_loss: 2.7479
Epoch 3/3
[1m2349/2349[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1201s[0m 511ms/step - accuracy: 0.1910 - loss: 2.6448 - val_accuracy: 0.1941 - val_loss: 2.6092


<keras.src.callbacks.history.History at 0x7dff7a4ad8d0>

## Inference:

In [1]:
MODEL_PATH = "checkpoints/best_model.keras"
VOCAB_FILE = "vocab.txt"
MAX_LEN = 30

In [2]:
def load_vectorizer_from_vocab(vocab_file, max_len, max_tokens=None):
    vocab = [line.rstrip("\n") for line in open(vocab_file, "r", encoding="utf-8")]
    # create new TextVectorization with same params as training
    vectorizer = layers.TextVectorization(
        max_tokens=max_tokens or len(vocab),
        output_sequence_length=max_len,
        standardize="lower_and_strip_punctuation",
        output_mode="int"
    )
    # set vocabulary so it doesn't require adapt()
    vectorizer.set_vocabulary(vocab)
    return vectorizer

In [3]:
def preprocess_image_from_path(image_path):
    img = tf.io.read_file(image_path)
    img = tf.image.decode_image(img, channels=3)
    img = tf.image.resize(img, IMG_SIZE)
    img = tf.cast(img, tf.float32) / 255.0
    return img  # float32 H,W,3 scaled to [0,1]

In [7]:
model = keras.models.load_model(MODEL_PATH, compile=False)

I0000 00:00:1754759303.392367   53933 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 3584 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3060 Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.6
2025-08-09 22:38:36.005729: W external/local_xla/xla/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 9437184 exceeds 10% of free system memory.
2025-08-09 22:38:36.799607: W external/local_xla/xla/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 20480000 exceeds 10% of free system memory.


In [8]:
vectorizer = load_vectorizer_from_vocab(VOCAB_FILE, max_len=MAX_LEN, max_tokens=None)
vocab = vectorizer.get_vocabulary()
id_to_token = {i: t for i, t in enumerate(vocab)}

In [9]:
start_candidates = ["<start>", "start"]
end_candidates = ["<end>", "end"]
start_id = None
end_id = None
for w, idx in zip(vocab, range(len(vocab))):
    lw = w.lower()
    if start_id is None and lw in start_candidates:
        start_id = idx
    if end_id is None and lw in end_candidates:
        end_id = idx
# Fallbacks if your tokens are different
if start_id is None:
    # If you did not use explicit tokens, you can treat generation as starting from an "empty" prompt.
    # Here we fall back to using token id for first non-zero token (less ideal).
    start_id = 1
if end_id is None:
    end_id = None  # we'll rely on max length

In [19]:
def generate_caption_greedy(image_path, model, vectorizer, max_len=MAX_LEN):
    img = preprocess_image_from_path(image_path)
    img = tf.expand_dims(img, 0)  # batch dim

    # initialize with start token
    decoded = [start_id]
    for i in range(max_len - 1):
        # prepare decoder input (batch, seq_len)
        dec_input = tf.expand_dims(tf.constant(decoded, dtype=tf.int32), 0)  # shape (1, cur_len)
        # model expects (image_batch, caption_input). Provide current decoder tokens.
        preds = model([img, dec_input], training=False)  # (1, seq_len, vocab)
        # take logits at last timestep
        logits = preds[0, -1, :]  # (vocab,)
        next_id = int(tf.argmax(logits).numpy())
        decoded.append(next_id)
        if (end_id is not None) and (next_id == end_id):
            break

    # Convert token ids to tokens, strip start & end
    # Remove the first token (start) and any end token and padding (id 0)
    token_ids = decoded[1:]  # skip start
    words = []
    for tid in token_ids:
        if tid == 0:
            continue
        tok = id_to_token.get(tid, "")
        if end_id is not None and tid == end_id:
            break
        words.append(tok)
    caption = " ".join(w for w in words if w not in ("<start>", "<end>"))
    # basic cleanup: collapse multiple spaces
    caption = " ".join(caption.split())
    return caption

# ---- Example usage ----
image_path = "COCO_train2014_000000000094.jpg"
caption = generate_caption_greedy(image_path, model, vectorizer)
print("Generated caption:", caption)

Generated caption: a street sign with a traffic sign on the street
