In [None]:
import tensorflow as tf
from tensorflow.keras import Model, layers
from transformers import TFGPT2Model, GPT2Tokenizer
import numpy as np
from PIL import Image
import os

class ImageFeatureExtractor(layers.Layer):
    def __init__(self, output_dim):
        super(ImageFeatureExtractor, self).__init__()
        base_model = tf.keras.applications.MobileNetV2(
            include_top=False,
            weights='imagenet',
            input_shape=(224, 224, 3)
        )
        base_model.trainable = False
        self.cnn = base_model
        self.global_pool = layers.GlobalAveragePooling2D()
        self.projection = layers.Dense(output_dim, activation='relu')

    def call(self, images):
        x = self.cnn(images)
        x = self.global_pool(x)
        return self.projection(x)

class ProjectionLayer(layers.Layer):
    def __init__(self, embedding_dim):
        super(ProjectionLayer, self).__init__()
        self.dense = layers.Dense(embedding_dim)

    def call(self, image_features):
        return self.dense(image_features)

class ImageCaptioningModel(Model):
    def __init__(self, max_length=50, vocab_size=50257):
        super(ImageCaptioningModel, self).__init__()
        self.gpt2 = TFGPT2Model.from_pretrained('gpt2')
        self.gpt2.trainable = False
        self.image_encoder = ImageFeatureExtractor(
            output_dim=self.gpt2.config.hidden_size
        )
        self.projection = ProjectionLayer(self.gpt2.config.hidden_size)
        self.output_layer = layers.Dense(vocab_size, activation='softmax')
        self.max_length = max_length

    def call(self, inputs):
        images, text_tokens = inputs
        image_features = self.image_encoder(images)
        projected_features = self.projection(image_features)
        text_outputs = self.gpt2(text_tokens, return_dict=True)
        text_features = text_outputs.last_hidden_state
        expanded_features = tf.expand_dims(projected_features, axis=1)
        expanded_features = tf.tile(
            expanded_features,
            [1, tf.shape(text_features)[1], 1]
        )
        combined_features = text_features + expanded_features
        outputs = self.output_layer(combined_features)
        return outputs

def preprocess_image(image_path):
    """Load and preprocess a single image."""
    img = Image.open(image_path)
    img = img.convert('RGB')
    img = img.resize((224, 224))
    img_array = tf.keras.preprocessing.image.img_to_array(img)
    img_array = tf.keras.applications.mobilenet_v2.preprocess_input(img_array)
    return img_array

def generate_caption(model, image_path, max_length=50):
    """Generate a caption for a single image."""
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token

    img = preprocess_image(image_path)
    img = tf.expand_dims(img, 0)
    current_tokens = tf.convert_to_tensor([[tokenizer.bos_token_id]], dtype=tf.int32)
    generated_caption = []

    for _ in range(max_length):
        predictions = model((img, current_tokens))
        next_token = tf.argmax(predictions[:, -1, :], axis=-1)
        token_id = int(next_token.numpy()[0])
        generated_caption.append(token_id)

        if token_id == tokenizer.eos_token_id:
            break

        current_tokens = tf.convert_to_tensor([generated_caption], dtype=tf.int32)

    caption = tokenizer.decode(generated_caption, skip_special_tokens=True)
    return caption

# Example usage
def test_model(checkpoint_path, image_path):
    """Test the model on a single image."""
    # Initialize model
    model = ImageCaptioningModel(max_length=50)
    
    # Create a dummy input to build the model
    dummy_image = tf.zeros((1, 224, 224, 3))
    dummy_tokens = tf.zeros((1, 1), dtype=tf.int32)
    _ = model((dummy_image, dummy_tokens))
    
    # Load weights
    model.load_weights(checkpoint_path)
    
    # Generate caption
    caption = generate_caption(model, image_path)
    return caption

2025-01-15 20:15:13.573240: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1736968513.615790    4399 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1736968513.628326    4399 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-15 20:15:13.736987: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm
I0000 00:00:1736968516.275669    4399 gpu_device.cc:2022] Created device /j

Generated caption: AA airplane flying in the sky.


In [6]:
# Replace these paths with your actual paths
model_path = "model/image_captioning_model_epoch_from_train10.weights.h5"  # Update with your model path
image_path = "archive/coco2017/test2017/000000581823.jpg"  # Update with your test image path

caption = test_model(model_path, image_path)
print(f"Generated caption: {caption}")

All PyTorch model weights were used when initializing TFGPT2Model.

All the weights of TFGPT2Model were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFGPT2Model for predictions without further training.


Generated caption: AA traffic light on a city street.
