In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

# Load MS COCO 2014 together with its captions. The tfds "coco/2014" config is the
# object-detection dataset and has no "captions" feature; "coco_captions" stores them
# as a Sequence of {"id", "text"}, i.e. example["captions"]["text"] is a 1-D string tensor.
data, info = tfds.load("coco_captions/2014", with_info=True)


def _caption_texts(dataset):
    """Yield every caption of every example in ``dataset`` as a ``str``."""
    # Keep only the caption strings so the corpus pass does not carry image tensors along.
    captions_only = dataset.map(lambda example: example["captions"]["text"])
    for image_captions in tfds.as_numpy(captions_only):
        for caption in image_captions:
            yield caption.decode("utf-8")


# Build a subword vocabulary from all training captions.
tokenizer = tfds.deprecated.text.SubwordTextEncoder.build_from_corpus(
    _caption_texts(data["train"]),
    target_vocab_size=2**13)

In [None]:
def encode_caption(caption):
    """Wrap ``caption`` in ``<start>``/``<end>`` markers and encode it with ``tokenizer``.

    Args:
        caption: the caption text as a ``str``, as UTF-8 ``bytes``, or as an eager
            scalar string tensor (anything exposing ``.numpy()``, e.g. the value
            ``tf.py_function`` passes to its callback).

    Returns:
        Whatever ``tokenizer.encode`` returns for the marked-up caption.
    """
    # Unwrap eager tensors first; their .numpy() value is bytes for string tensors.
    if hasattr(caption, "numpy"):
        caption = caption.numpy()
    # "<start> " + b"..." would raise TypeError, so normalise bytes to text.
    if isinstance(caption, bytes):
        caption = caption.decode("utf-8")
    return tokenizer.encode("<start> " + caption + " <end>")

def preprocess(example, image_size=(224, 224)):
    """Turn one dataset example into an ``(image, caption_ids)`` training pair.

    Meant to run inside ``tf.data.Dataset.map`` (graph mode), so the pure-Python
    tokenizer is reached through ``tf.py_function``.

    Args:
        example: dataset element with an ``"image"`` tensor and a ``"captions"``
            entry that is either a 1-D string tensor or a dict whose ``"text"``
            key holds one (the tfds ``coco_captions`` layout).
        image_size: ``(height, width)`` the image is resized to.

    Returns:
        ``(image, caption_ids)``: the resized image scaled by InceptionV3's
        ``preprocess_input``, and a 1-D int64 tensor of token ids for one
        caption chosen uniformly at random.
    """
    image = tf.image.resize(example["image"], image_size)
    # The ImageNet-trained InceptionV3 weights expect inputs scaled to [-1, 1].
    image = tf.keras.applications.inception_v3.preprocess_input(image)

    captions = example["captions"]
    if isinstance(captions, dict):  # coco_captions: Sequence({"id", "text"})
        captions = captions["text"]
    index = tf.random.uniform([], maxval=tf.shape(captions)[0], dtype=tf.int32)
    caption = captions[index]

    # tokenizer.encode is plain Python and cannot consume symbolic tensors.
    caption_ids = tf.py_function(
        lambda text: encode_caption(text.numpy().decode("utf-8")),
        [caption],
        tf.int64,
    )
    # py_function drops static shape information; padded_batch needs the rank.
    caption_ids.set_shape([None])
    return image, caption_ids

In [None]:
# Prepare the dataset. Encoded captions have different lengths, so each batch is
# zero-padded to its longest caption (SubwordTextEncoder reserves id 0 for padding).
BATCH_SIZE = 32
train_data = (
    data["train"]
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .padded_batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Define the model architecture.
# InceptionV3 is used as a frozen feature extractor; global average pooling turns
# its spatial feature map into a single feature vector per image.
UNITS = 256
image_model = tf.keras.applications.InceptionV3(
    include_top=False, weights="imagenet", pooling="avg")
image_model.trainable = False
image_features_extractor = tf.keras.Model(image_model.input, image_model.output)

# Caption decoder: the image vector is projected to the LSTM's initial (h, c) state,
# and the LSTM predicts token t+1 from tokens 0..t, one output per input token.
features_input = tf.keras.Input(shape=(image_model.output.shape[-1],), name="image_features")
tokens_input = tf.keras.Input(shape=(None,), dtype="int64", name="tokens")
initial_h = tf.keras.layers.Dense(UNITS, activation="tanh")(features_input)
initial_c = tf.keras.layers.Dense(UNITS, activation="tanh")(features_input)
# mask_zero=True: SubwordTextEncoder reserves id 0 for padding.
embedded = tf.keras.layers.Embedding(tokenizer.vocab_size, UNITS, mask_zero=True)(tokens_input)
hidden = tf.keras.layers.LSTM(UNITS, return_sequences=True)(
    embedded, initial_state=[initial_h, initial_c])
hidden = tf.keras.layers.Dense(UNITS, activation="relu")(hidden)
next_token_probs = tf.keras.layers.Dense(tokenizer.vocab_size, activation="softmax")(hidden)
caption_model = tf.keras.Model([features_input, tokens_input], next_token_probs)

# Define the loss function and optimizer. The model outputs probabilities (softmax),
# so from_logits stays False; reduction="none" keeps a per-token loss for masking.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(reduction="none")
optimizer = tf.keras.optimizers.Adam()


# Define the training step
@tf.function
def train_step(images, captions):
    """Run one optimisation step of the caption decoder on a batch.

    Args:
        images: batch of images prepared as in ``preprocess``.
        captions: (batch, seq_len) integer token ids; assumed right-padded with 0.

    Returns:
        Scalar mean cross-entropy over the non-padding target tokens.
    """
    # The CNN is frozen, so its forward pass does not need to be on the tape.
    features = image_features_extractor(images, training=False)
    # Teacher forcing: token t is the input used to predict token t+1.
    decoder_inputs = captions[:, :-1]
    targets = captions[:, 1:]
    with tf.GradientTape() as tape:
        probs = caption_model([features, decoder_inputs], training=True)
        per_token_loss = loss_fn(targets, probs)
        # Padding targets (id 0) must not contribute to the loss.
        mask = tf.cast(tf.not_equal(targets, 0), per_token_loss.dtype)
        loss = tf.reduce_sum(per_token_loss * mask) / tf.maximum(tf.reduce_sum(mask), 1.0)
    # Update the decoder parameters
    gradients = tape.gradient(loss, caption_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, caption_model.trainable_variables))
    return loss

In [None]:
# Train the model, reporting the mean batch loss once per epoch instead of
# printing a line for every batch.
NUM_EPOCHS = 10
for epoch in range(NUM_EPOCHS):
    epoch_loss = 0.0
    num_batches = 0
    for images, captions in train_data:
        epoch_loss += float(train_step(images, captions))
        num_batches += 1
    # max(..., 1) guards against an empty dataset.
    print(f"Epoch {epoch+1}, Loss {epoch_loss / max(num_batches, 1):.4f}")

In [None]:
# Use the trained model to caption an image, optionally continuing a text prompt.
def generate_caption(image, prompt="", max_length=30):
    """Greedily decode a caption for one image, starting from ``prompt``.

    The model is run autoregressively. Each step feeds back every token chosen so
    far and appends the most probable next token. Decoding stops when the decoded
    text contains ``<end>``, when padding id 0 is predicted, or after
    ``max_length`` new tokens.

    Args:
        image: one image of shape (H, W, 3) or (1, H, W, 3), prepared the same
            way as the training images (see ``preprocess``).
        prompt: text the caption should begin with; ``""`` lets the model
            choose the whole caption.
        max_length: maximum number of tokens generated after the prompt.

    Returns:
        The caption text with the ``<start>``/``<end>`` markers removed.

    Assumes ``caption_model`` maps ``[image_features, token_ids]`` to per-step
    next-token probabilities of shape (1, steps, vocab_size).
    """
    image = tf.convert_to_tensor(image)
    if image.shape.rank == 3:
        image = image[tf.newaxis]  # the CNN expects a batch dimension
    features = image_features_extractor(image, training=False)

    # No " <end>" suffix: the prompt is the beginning of a caption, not a full one.
    token_ids = list(tokenizer.encode("<start> " + prompt))
    text = tokenizer.decode(token_ids)
    for _ in range(max_length):
        probs = caption_model(
            [features, tf.constant([token_ids], dtype=tf.int64)], training=False)
        next_id = int(tf.argmax(probs[0, -1]))
        if next_id == 0:  # padding id: the model has nothing more to add
            break
        token_ids.append(next_id)
        text = tokenizer.decode(token_ids)
        if "<end>" in text:
            break
    text = text.split("<end>", 1)[0]
    return text.replace("<start>", "", 1).strip()