## Imports

In [None]:
import tensorflow as tf
from tensorflow.keras.applications import ViT
from tensorflow.keras.layers import LSTM, Dense, Embedding
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow_text import Tokenizer


## Tokenizer and image encoder

In [None]:
# Word-level Keras Tokenizer: the rest of the notebook relies on its API
# (`num_words`, `word_index`, `index_word`, `texts_to_sequences`,
# `sequences_to_texts`). `tensorflow_text.Tokenizer` is an abstract base class
# with none of these and no `num_words` argument.
# The Keras default `filters` also strip "<" and ">", which would turn the
# "<start>"/"<end>" markers looked up in `generate_caption` into "start"/"end",
# so those two characters are left out of the filter set below.
# NOTE: legacy Keras 2 API; under Keras 3 it is available via the `tf_keras` package.
tokenizer = tf.keras.preprocessing.text.Tokenizer(
    num_words=10000,  # Adjust vocabulary size as needed
    filters='!"#$%&()*+,-./:;=?@[\\]^_`{|}~\t\n',
    oov_token="<unk>",  # map rare/unseen words to <unk> instead of dropping them
)
# Image encoder used as the first input of the captioning model.
# FIXME(review): `tf.keras.applications` does not provide a `ViT` class, so the
# `from tensorflow.keras.applications import ViT` import fails and this line never
# runs. The Hugging Face-style `model_name="google/vit-base-patch32"` suggests a
# different library was intended (e.g. `transformers.TFViTModel.from_pretrained`,
# whose inputs/outputs differ from a Keras application). Choose one explicitly and
# check that `vit_model.input` / `vit_model.output` still exist for the model-build cell.
vit_model = ViT(model_name="google/vit-base-patch32", include_top=False)


## Image input preparation

In [None]:
def prepare_image_input(image, mode="caffe"):
    """Preprocess raw pixels into the form fed to the captioning model's image input.

    The captioning model takes raw images as its first input (`vit_model.input`)
    and runs the encoder itself. This function must therefore NOT call
    `vit_model`. Doing so fed encoder *features* into a second encoder pass, and
    the trailing `[0]` also dropped the batch dimension.

    Args:
        image: image tensor or array. Its rank is preserved, so a single HxWxC
            image (e.g. inside `Dataset.map`) or a batch both work.
        mode: preprocessing mode passed to `imagenet_utils.preprocess_input`.
            "caffe" is the Keras default and the previous behaviour.
            NOTE(review): ViT checkpoints are typically trained on [-1, 1] scaled
            pixels (mode="tf"); confirm against the encoder actually used.

    Returns:
        A float32 tensor with the same shape as `image`.
    """
    # Cast first: preprocess_input may modify float NumPy arrays in place, and it
    # does not promote integer tensors to float before subtracting/scaling.
    image = tf.cast(image, tf.float32)
    return tf.keras.applications.imagenet_utils.preprocess_input(image, mode=mode)


## Training dataset

In [None]:
def create_dataset(images, captions, batch_size=32, shuffle_buffer=1000, max_length=None):
    """Build a batched tf.data pipeline of teacher-forced captioning examples.

    Each batch element is ``((image, input_ids), target, sample_weight)``. This
    matches ``model``'s two inputs ``[encoder_inputs, decoder_inputs]``:

    * ``input_ids`` is the tokenised caption without its last token. ``target``
      is the same caption shifted left by one (next-token prediction).
    * ``target`` is one-hot over ``tokenizer.num_words`` classes, as required by
      the ``CategoricalCrossentropy`` loss used for training.
    * ``sample_weight`` is 0 on padding positions, so padding does not count
      toward the loss.

    Side effect: if the global ``tokenizer`` has an empty vocabulary, it is
    fitted on ``captions``. A later call (e.g. for a validation split) then
    reuses that vocabulary.

    Args:
        images: images indexable along axis 0, one per caption.
        captions: iterable of caption strings. ``generate_caption`` looks up
            ``"<start>"`` and ``"<end>"`` in the vocabulary, so captions are
            expected to be wrapped in those markers.
        batch_size: examples per batch.
        shuffle_buffer: shuffle buffer size; 0 or None disables shuffling.
        max_length: optional cap on tokenised caption length. Longer captions
            are truncated at the end. Defaults to the longest caption.

    Returns:
        A batched, prefetched ``tf.data.Dataset``.

    Raises:
        ValueError: if ``captions`` is empty.
    """
    captions = list(captions)
    if not captions:
        raise ValueError("create_dataset needs at least one caption")

    # The Keras Tokenizer is a plain Python object. It is not callable and cannot
    # run inside `Dataset.map` (which traces a graph), so tokenise eagerly here.
    if not tokenizer.word_index:
        tokenizer.fit_on_texts(captions)
    sequences = tokenizer.texts_to_sequences(captions)
    # Pad to a rectangular int array so the captions can be sliced by from_tensor_slices.
    padded = tf.keras.preprocessing.sequence.pad_sequences(
        sequences, maxlen=max_length, padding="post", truncating="post"
    )
    input_ids = padded[:, :-1]
    target_ids = padded[:, 1:]

    def to_training_example(image, decoder_input, decoder_target):
        target = tf.one_hot(decoder_target, depth=tokenizer.num_words)
        # Id 0 is the padding id, so give those positions zero weight in the loss.
        weight = tf.cast(tf.not_equal(decoder_target, 0), tf.float32)
        return (prepare_image_input(image), decoder_input), target, weight

    dataset = tf.data.Dataset.from_tensor_slices((images, input_ids, target_ids))
    if shuffle_buffer:
        dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.map(to_training_example, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)


## Build the captioning model

In [None]:
# Encoder-decoder captioning model. The image encoder's output seeds the LSTM
# decoder's state, and the decoder predicts the next token at every position.
embedding_dim = 256  # token embedding size; tune as needed
lstm_units = 512

encoder_inputs = vit_model.input
encoder_outputs = vit_model.output
# Collapse any token/spatial axes into a single vector per image. This works
# whether the backbone returns (batch, C), (batch, tokens, C) or (batch, H, W, C).
# NOTE(review): assumes a single output tensor with features on the last axis.
encoder_positions = tf.keras.layers.Reshape((-1, encoder_outputs.shape[-1]))(encoder_outputs)
image_vector = tf.keras.layers.GlobalAveragePooling1D()(encoder_positions)
# LSTM(initial_state=...) expects [h, c], each shaped (batch, lstm_units). The raw
# encoder output has neither that structure nor that width, so project it.
state_h = Dense(lstm_units, activation="tanh")(image_vector)
state_c = Dense(lstm_units, activation="tanh")(image_vector)

decoder_inputs = tf.keras.layers.Input(shape=(None,), dtype="int32")
# mask_zero: the Keras Tokenizer never assigns id 0, so it is used as padding.
# Masking lets padded steps be skipped downstream.
decoder_embedding = Embedding(
    input_dim=tokenizer.num_words, output_dim=embedding_dim, mask_zero=True
)(decoder_inputs)
decoder_lstm = LSTM(units=lstm_units, return_sequences=True)(
    decoder_embedding, initial_state=[state_h, state_c]
)
decoder_outputs = Dense(tokenizer.num_words, activation="softmax")(decoder_lstm)
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)


## Train the model

In [None]:
EPOCHS = 10  # Adjust epochs as needed

# The model ends in a softmax, so plain (non-logit) categorical cross-entropy is
# used. It expects one-hot targets over `tokenizer.num_words` classes; switch to
# SparseCategoricalCrossentropy if the dataset yields integer token ids instead.
optimizer = Adam(learning_rate=0.001)
loss = tf.keras.losses.CategoricalCrossentropy()
model.compile(optimizer=optimizer, loss=loss)

# Load your image and caption data here. The previous `images, captions = ...`
# placeholder failed with an opaque "cannot unpack non-iterable ellipsis" error.
images, captions = None, None  # TODO: replace with your dataset loader
if images is None or captions is None:
    raise NotImplementedError("Load `images` and `captions` before running this cell.")
dataset = create_dataset(images, captions)

history = model.fit(dataset, epochs=EPOCHS)


## Generate captions (greedy decoding)

In [None]:
def generate_caption(image, max_caption_length=20):
    """Greedily decode a caption for one image.

    Starting from "<start>", the argmax token is appended at each step and the
    full prefix is fed back to the model. Decoding stops at "<end>", at the
    padding id 0, or after ``max_caption_length`` generated tokens.

    Args:
        image: a single HxWxC image, or a batch containing one image.
        max_caption_length: maximum number of tokens to generate.

    Returns:
        The caption text, without the "<start>"/"<end>" markers.

    Raises:
        KeyError: if "<start>" is not in the tokenizer vocabulary.
    """
    start_id = tokenizer.word_index["<start>"]
    end_id = tokenizer.word_index.get("<end>")

    image = tf.convert_to_tensor(image)
    if image.shape.rank == 3:
        image = image[tf.newaxis]  # the model's image input is batched
    model_image = prepare_image_input(image)

    token_ids = [start_id]
    for _ in range(max_caption_length):
        # Feed the whole prefix each step. The old code fed only the last token,
        # so the decoder had no history and the caption was a single word.
        decoder_input = tf.constant([token_ids], dtype=tf.int32)
        # Direct call instead of model.predict: avoids per-call predict overhead
        # and progress output inside the loop.
        predictions = model([model_image, decoder_input], training=False)
        predicted_id = int(tf.argmax(predictions[0, -1, :]))
        if predicted_id == 0 or predicted_id == end_id:  # 0 is padding, not a word
            break
        token_ids.append(predicted_id)

    return tokenizer.sequences_to_texts([token_ids[1:]])[0]
