# Multimodal Image Captioning using ResNet50 and FLAN-T5


This notebook demonstrates how to create a multimodal model that combines image features extracted from ResNet50 with text generation using the FLAN-T5 model. The model is trained to generate captions for images using the Flickr8k dataset.


In [None]:
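# Install the libraries used in this notebook
# transformers provides FLAN-T5; its TensorFlow models need the tf-keras package when Keras 3 is installed
!pip install 'keras==3.2' datasets tensorflow transformers tf-keras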

## Step 1: Load and Preprocess the Data
In this step, we load a small subset of the Flickr8k dataset from the Hugging Face Hub. The images and captions are converted into model inputs in Step 3.

In [None]:
# Import the libraries for data loading and processing
# Flickr8k is loaded from the Hugging Face Hub with the datasets library
# Only the first 10 training examples are used, for demonstration purposes

from datasets import load_dataset
import tensorflow as tf
from transformers import AutoTokenizer

# Load the flickr8k dataset
dataset = load_dataset("Naveengo/flickr8k", split='train[:10]')


## Step 2: Define the Model Architecture
Here, we define the multimodal model architecture. We will use a pre-trained ResNet50 model for image features and a FLAN-T5 model for generating text captions.
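
Before building the Keras graph, it can help to trace the tensor shapes through the fusion path. The sketch below is a NumPy stand-in, not the notebook's model: random arrays replace the ResNet50 feature map and the T5 encoder output, and each `Dense` layer is a plain matrix multiplication with random weights (biases omitted). It assumes a batch of 2, a 7x7x2048 ResNet50 feature map for a 224x224 input, and a model width of 768 for flan-t5-base.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, d_model, img_dim = 2, 128, 768, 2048

# ResNet50 (include_top=False) on a 224x224 image yields a 7x7x2048 feature map
feature_map = rng.standard_normal((batch, 7, 7, img_dim))

# GlobalAveragePooling2D: average over the two spatial axes -> (batch, 2048)
pooled = feature_map.mean(axis=(1, 2))

# expand_dims + tile: repeat the image vector at every text position -> (batch, 128, 2048)
tiled = np.tile(pooled[:, None, :], (1, seq_len, 1))

# Dense(768) stand-in (random kernel, no bias) -> (batch, 128, 768)
w_img = rng.standard_normal((img_dim, d_model)) * 0.01
img_proj = tiled @ w_img

# Stand-in for the T5 encoder's last_hidden_state -> (batch, 128, 768)
encoder_out = rng.standard_normal((batch, seq_len, d_model))

# Concatenate on the feature axis -> (batch, 128, 1536), then project back to 768
combined = np.concatenate([encoder_out, img_proj], axis=-1)
w_proj = rng.standard_normal((2 * d_model, d_model)) * 0.01
fused = combined @ w_proj

print(pooled.shape, tiled.shape, combined.shape, fused.shape)
```

The last shape, (batch, 128, 768), is what the T5 decoder receives as `encoder_hidden_states`.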

In [None]:
# Define the architecture of the multimodal model
# A pretrained ResNet50 extracts the image features
# The input shape (224, 224, 3) is the standard size for ImageNet-pretrained ResNet50
# include_top=False removes the classification head so the network outputs a feature map; its weights are frozen
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import GlobalAveragePooling2D, Input, Dense, Concatenate, Lambda
from tensorflow.keras.models import Model
from transformers import TFAutoModelForSeq2SeqLM
import tensorflow as tf

# Load ResNet50 without the top layers (for feature extraction)
image_model = ResNet50(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
image_model.trainable = False  # Freeze the layers

# Define image input
image_input = Input(shape=(224, 224, 3), name='image_input')
image_features = image_model(image_input)
image_features = GlobalAveragePooling2D()(image_features)

# Load the FLAN-T5 model
text_model = TFAutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

# Define text inputs
input_ids = Input(shape=(128,), dtype=tf.int32, name='input_ids')
attention_mask = Input(shape=(128,), dtype=tf.int32, name='attention_mask')

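# Wrap the T5 encoder in a Lambda layer with an explicit output shape (d_model = 768 for flan-t5-base)
# Note: weights used inside a Lambda layer are not tracked by Keras, so fit() does not update the T5 weights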
encoder_outputs = Lambda(
    lambda inputs: text_model.encoder(input_ids=inputs[0], attention_mask=inputs[1]).last_hidden_state,
    output_shape=(128, 768)
)([input_ids, attention_mask])


# Expand image features to match the sequence length of the text (128)
image_features_expanded = Lambda(lambda x: tf.expand_dims(x, 1), output_shape=(1, 2048))(image_features)
image_features_expanded = Lambda(lambda x: tf.tile(x, [1, 128, 1]), output_shape=(128, 2048))(image_features_expanded)
image_features_expanded = Dense(768)(image_features_expanded)

# Combine image features with text encoder outputs
combined_features = Concatenate(axis=-1)([encoder_outputs, image_features_expanded])  # Shape: (None, 128, 1536)

projected_features = Dense(768)(combined_features)  # Shape: (None, 128, 768)


# Define the decoder inputs
decoder_input_ids = Input(shape=(128,), dtype=tf.int32, name='decoder_input_ids')
decoder_attention_mask = Input(shape=(128,), dtype=tf.int32, name='decoder_attention_mask')

# Run the T5 decoder over the fused encoder states; the encoder mask stops cross-attention to padding
decoder_outputs = Lambda(
    lambda inputs: text_model.decoder(
        input_ids=inputs[0],
        attention_mask=inputs[1],
        encoder_hidden_states=inputs[2],
        encoder_attention_mask=inputs[3],
    ).last_hidden_state,
    output_shape=(128, 768)
)([decoder_input_ids, decoder_attention_mask, projected_features, attention_mask])


# The Lambda above already extracts last_hidden_state, so no tuple unpacking is needed here
decoder_last_hidden_state = decoder_outputs

# Project decoder states to vocabulary logits with FLAN-T5's own lm_head (no softmax is applied)
output = Lambda(lambda x: text_model.lm_head(x), output_shape=(128, text_model.config.vocab_size))(decoder_last_hidden_state)

# Define the complete model
multimodal_model = Model(
    inputs=[image_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask],
    outputs=output
)

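# Compile the model; lm_head returns unnormalized logits, so the loss must use from_logits=True
multimodal_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
)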

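FLAN-T5's `lm_head` is a linear projection, so the model outputs unnormalized logits rather than probabilities. The NumPy sketch below computes sparse categorical cross-entropy the way a `from_logits=True` loss does conceptually (log-softmax, then the negative log-probability of each target id) on a toy vocabulary of 3 tokens. The arrays are illustrative, not model outputs.

```python
import numpy as np

def log_softmax(logits):
    # Subtract the row max before exponentiating for numerical stability
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def sparse_ce_from_logits(logits, targets):
    """Mean token-level cross-entropy, treating `logits` as unnormalized scores."""
    log_probs = log_softmax(logits)
    # Gather the log-probability assigned to each target token id
    picked = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return float(-picked.mean())

# Toy lm_head output: batch=1, seq_len=2, vocab=3
logits = np.array([[[2.0, 0.5, -1.0],
                    [0.1, 3.0, 0.2]]])
targets = np.array([[0, 1]])

# Raw logits are not probabilities: they can be negative and rows do not sum to 1
print(logits.sum(axis=-1))  # row sums are 1.5 and 3.3
# Softmax turns each row into a distribution before the loss is taken
print(np.exp(log_softmax(logits)).sum(axis=-1))
print(sparse_ce_from_logits(logits, targets))
```

A probability-based loss applied to these raw scores would misread them, which is why the loss object needs `from_logits=True` here.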

In [None]:
multimodal_model.summary()
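
Two properties of the fusion path are worth checking. Because `Dense` acts independently on each sequence position, projecting the pooled 2048-d image vector before tiling gives the same result as tiling first, with 128 times fewer multiplications. And `Concatenate` followed by `Dense` equals the sum of two separate projections, so the image branch adds the same vector to every encoder position. The NumPy check below uses random stand-in weights; it is a sketch for intuition, not a change to the model.

```python
import numpy as np

rng = np.random.default_rng(1)
batch, seq_len, img_dim, d_model = 2, 128, 2048, 768
pooled = rng.standard_normal((batch, img_dim))       # stand-in for the GlobalAveragePooling2D output
w = rng.standard_normal((img_dim, d_model)) * 0.01   # stand-in Dense(768) kernel
b = rng.standard_normal(d_model) * 0.01              # stand-in Dense(768) bias

# Order used in the notebook: tile to (batch, 128, 2048), then apply Dense
tile_then_dense = np.tile(pooled[:, None, :], (1, seq_len, 1)) @ w + b
# Cheaper order: apply Dense to the (batch, 2048) vector, then tile to (batch, 128, 768)
dense_then_tile = np.tile((pooled @ w + b)[:, None, :], (1, seq_len, 1))
print(np.allclose(tile_then_dense, dense_then_tile))

# Concatenate + Dense (bias omitted) equals the sum of two separate projections
enc = rng.standard_normal((batch, seq_len, d_model))  # stand-in for the T5 encoder output
w_cat = rng.standard_normal((2 * d_model, d_model)) * 0.01
concat_then_dense = np.concatenate([enc, dense_then_tile], axis=-1) @ w_cat
sum_of_projections = enc @ w_cat[:d_model] + dense_then_tile @ w_cat[d_model:]
print(np.allclose(concat_then_dense, sum_of_projections))
```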

## Step 3: Prepare the Dataset for Training
In this step, we convert the images and captions into the tensors the model expects: resized and normalized images for ResNet50, tokenized captions for FLAN-T5, and shifted decoder inputs for training.
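
`tf.keras.applications.resnet50.preprocess_input` uses Keras's "caffe" preprocessing mode: it reorders channels from RGB to BGR and subtracts the per-channel ImageNet means, without rescaling to a fixed range. The NumPy re-implementation below is a sketch for intuition (the mean values are the ones Keras uses for this mode); the notebook itself should keep calling `preprocess_input`.

```python
import numpy as np

# ImageNet per-channel means in BGR order, as used by Keras's "caffe" preprocessing mode
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)

def caffe_preprocess(rgb_image):
    """Sketch of ResNet50-style preprocessing for an RGB array with values in [0, 255]."""
    bgr = rgb_image[..., ::-1].astype(np.float32)  # reorder channels RGB -> BGR
    return bgr - IMAGENET_MEAN_BGR                 # zero-center each channel, no rescaling

# A 1x2 image: one pure red pixel and one black pixel
pixels = np.array([[[255, 0, 0], [0, 0, 0]]], dtype=np.uint8)
out = caffe_preprocess(pixels)
print(out[0, 0])  # red pixel: blue and green go negative, the red channel stays far above 1
print(out[0, 1])  # black pixel: just the negated channel means
```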

In [None]:
# Preprocess images and captions into the inputs expected by the multimodal model
# Images are resized to 224x224 and normalized for ResNet50
# Captions are tokenized and padded or truncated to 128 tokens for FLAN-T5
import tensorflow as tf
from transformers import AutoTokenizer

# Preprocess the images for ResNet50
def preprocess_image(image):
    image = tf.image.resize(image, (224, 224))
    image = tf.keras.applications.resnet50.preprocess_input(image)  # RGB -> BGR, then subtract ImageNet channel means (no scaling)
    return image

# Preprocess the captions for FLAN-T5
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

def preprocess_caption(caption):
    encoding = tokenizer(caption, padding='max_length', truncation=True, max_length=128, return_tensors="tf")
    return encoding['input_ids'], encoding['attention_mask']

def preprocess_function(batch):
    # Preprocess images in the batch
    images = [preprocess_image(tf.keras.preprocessing.image.img_to_array(item)) for item in batch['image']]
    images = tf.stack(images)  # Stack into a single tensor

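    # Tokenize the captions in the batch
    # Note: the caption is also the encoder input here, so the decoder can read its own target via cross-attention
    input_ids, attention_mask = preprocess_caption(batch['text'])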

    # Create decoder inputs for teacher forcing: shift input_ids right by one, prepending pad_token_id (T5's decoder start token)
    decoder_input_ids = tf.concat([tf.fill([tf.shape(input_ids)[0], 1], tokenizer.pad_token_id), input_ids[:, :-1]], axis=1)

    # Ensure that all tensors are the correct shape (i.e., 128 tokens)
    input_ids = tf.ensure_shape(input_ids, [None, 128])
    attention_mask = tf.ensure_shape(attention_mask, [None, 128])
    decoder_input_ids = tf.ensure_shape(decoder_input_ids, [None, 128])

    # Return a dictionary with tensors
    return {
        'input_ids': input_ids,  # Tensor of input_ids
        'attention_mask': attention_mask,  # Tensor of attention_mask
        'decoder_input_ids': decoder_input_ids,  # Shifted input_ids for decoder
        'decoder_attention_mask': attention_mask,  # Same mask: it covers every position that predicts a real label token
        'image_input': images  # Tensor of processed images
    }

# Apply the preprocessing function
processed_dataset = dataset.map(preprocess_function, batched=True)
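
T5 is trained with teacher forcing: the decoder input is the label sequence shifted one step to the right, with the pad id (0 for T5 tokenizers) acting as the decoder start token. The NumPy sketch below uses arbitrary toy token ids (the id 1 stands in for T5's end-of-sequence token) and shows which target each decoder position should predict. The targets are the unshifted labels; if the shifted ids themselves were used as targets, each position would only be asked to copy the token it already receives.

```python
import numpy as np

PAD_ID = 0  # T5 tokenizers use 0 for padding, and T5 uses it as the decoder start token

def shift_right(labels, start_id=PAD_ID):
    """Build teacher-forcing decoder inputs: prepend the start id, drop the last position."""
    start = np.full((labels.shape[0], 1), start_id, dtype=labels.dtype)
    return np.concatenate([start, labels[:, :-1]], axis=1)

# Toy labels: three arbitrary token ids, an end-of-sequence id (1), then padding
labels = np.array([[37, 1782, 5, 1, 0, 0]])
decoder_input_ids = shift_right(labels)

# At position t the decoder has seen decoder_input_ids[: t + 1] and must predict labels[t]
for t in range(4):
    print(f"decoder input {int(decoder_input_ids[0, t]):>5} -> target {int(labels[0, t])}")
```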


In [None]:
# Prepare the dataset for training using TensorFlow
# TensorFlow datasets are used for efficient data handling and batching during training
# Convert necessary components of the dataset (image inputs and text inputs) into tensors
def prepare_dataset(dataset):
    input_ids_tensor = tf.convert_to_tensor(dataset['input_ids'])  # Convert list to TensorFlow tensor
    attention_mask_tensor = tf.convert_to_tensor(dataset['attention_mask'])  # Convert attention mask as well
    decoder_input_ids_tensor = tf.convert_to_tensor(dataset['decoder_input_ids'])  # Convert decoder input ids
    decoder_attention_mask_tensor = tf.convert_to_tensor(dataset['decoder_attention_mask'])  # Convert decoder attention mask
    image_input_tensor = tf.convert_to_tensor(dataset['image_input'])  # Convert image data

    return tf.data.Dataset.from_tensor_slices(({
                'input_ids': input_ids_tensor,  # Inputs
                'attention_mask': attention_mask_tensor,  # Attention mask for inputs
                'decoder_input_ids': decoder_input_ids_tensor,  # Decoder inputs
                'decoder_attention_mask': decoder_attention_mask_tensor,  # Attention mask for decoder
                'image_input': image_input_tensor  # Image inputs
            },
            input_ids_tensor  # Targets: the unshifted caption ids (decoder_input_ids are these shifted right by one)
           )).shuffle(1000).batch(16)

train_dataset = prepare_dataset(processed_dataset)
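
Captions are padded to 128 tokens, so most target positions in each row are padding, and an unmasked loss rewards the model for predicting pad ids. A common remedy is to average the per-token loss only over real tokens, using the attention mask as weights. The NumPy sketch below shows the idea with made-up loss values. In Keras, a similar effect should be available by yielding `(inputs, targets, sample_weights)` from the `tf.data` pipeline with the mask as per-token weights; depending on the loss reduction, Keras may normalize by the total number of positions instead of the mask sum, which rescales the loss but still keeps padding out of it.

```python
import numpy as np

def masked_token_loss(per_token_loss, mask):
    """Average per-token losses over real tokens only (positions where mask == 1)."""
    weights = mask.astype(per_token_loss.dtype)
    # Guard against an all-padding row to avoid dividing by zero
    return float((per_token_loss * weights).sum() / max(weights.sum(), 1.0))

# Made-up per-token losses for one caption of 4 real tokens padded to length 6
per_token_loss = np.array([[2.0, 1.0, 3.0, 0.5, 0.01, 0.01]])
mask = np.array([[1, 1, 1, 1, 0, 0]])

print(masked_token_loss(per_token_loss, mask))  # 1.625, the mean over the 4 real tokens
print(float(per_token_loss.mean()))             # lower, because easy pad positions are averaged in
```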


In [None]:
for batch in train_dataset.take(1):
    print(batch[0]['input_ids'].shape)
    print(batch[0]['attention_mask'].shape)
    print(batch[0]['decoder_input_ids'].shape)
    print(batch[0]['decoder_attention_mask'].shape)
    print(batch[0]['image_input'].shape)
    print(batch[1].shape)
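
With only 10 examples and `.batch(16)`, the pipeline yields a single partial batch, so the shapes printed above have a leading dimension of 10 rather than 16, and each epoch is one training step. The pure-Python helper below (`batch_sizes` is hypothetical, written for illustration) mirrors how `tf.data.Dataset.batch` splits a dataset, including its `drop_remainder` option.

```python
def batch_sizes(num_examples, batch_size, drop_remainder=False):
    """Hypothetical helper: sizes of the batches Dataset.batch() would yield."""
    full, rest = divmod(num_examples, batch_size)
    sizes = [batch_size] * full
    if rest and not drop_remainder:
        sizes.append(rest)  # the final partial batch is kept unless drop_remainder=True
    return sizes

print(batch_sizes(10, 16))                       # [10]: one partial batch, one step per epoch
print(batch_sizes(10, 16, drop_remainder=True))  # []: every example would be dropped
print(batch_sizes(100, 16))                      # six full batches and one of 4
```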

## Step 4: Train the Model
Finally, we train the model using the prepared dataset.

In [None]:
# Train the multimodal model on the prepared dataset for 5 epochs (full passes over the data)
history = multimodal_model.fit(train_dataset, epochs=5)
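
Keras stores the per-epoch training loss in `history.history['loss']`. Since the loss is a mean token-level cross-entropy in nats, `exp(loss)` gives the corresponding perplexity, and a model that guesses uniformly over a vocabulary of size V scores a loss of ln V. The helpers below are a small sketch for reading those numbers; the vocabulary size 32128 is an example value that should be checked against `text_model.config.vocab_size`. If padding positions are included in the loss, the perplexity will look better than the model's real captioning ability.

```python
import math

def perplexity(mean_token_ce):
    """Perplexity for a mean token-level cross-entropy measured in nats."""
    return math.exp(mean_token_ce)

def uniform_baseline_loss(vocab_size):
    """Cross-entropy of a model that assigns equal probability to every vocabulary entry."""
    return math.log(vocab_size)

# With the trained notebook model, the recorded losses could be inspected like this:
# print([round(perplexity(loss), 2) for loss in history.history["loss"]])

vocab_size = 32128  # example value; check text_model.config.vocab_size
print(perplexity(0.0))                                       # 1.0 for a perfect model
print(round(uniform_baseline_loss(vocab_size), 2))           # loss of uniform guessing
print(round(perplexity(uniform_baseline_loss(vocab_size))))  # recovers the vocabulary size
```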
