In [7]:
import tensorflow as tf
import transformers
from pickle import load, dump



In [4]:
# Load the precomputed image encodings and the training captions from disk.
# NOTE(review): pickle.load can execute arbitrary code while unpickling —
# only load these files if they come from a trusted source.
def _load_pickle(path):
    """Open *path* in binary mode, unpickle it and return the stored object."""
    with open(path, "rb") as handle:
        return load(handle)


image_encodings = _load_pickle("image_encodings.pkl")
train_captions = _load_pickle("train_captions.pkl")


In [18]:
# Short aliases used by the cells below.
# These are references to the loaded objects, not copies: mutating `captions`
# also mutates `train_captions` (likewise for `image_features`).
captions = train_captions
image_features = image_encodings

In [24]:

# Build the training set as a list of (image_feature, caption_text) pairs.
#
# Tokenisation is deferred to the training cell, where the GPT-2 tokenizer is
# loaded: `transformers.BatchEncoding` is only a container for tokenizer output
# and does not tokenize anything itself.
# A list (not a generator) is used so it can be iterated once per epoch and
# supports len().
#
# NOTE(review): assumes `captions` and `image_features` share the same keys
# (dict input) or positions (sequence input), and that each caption entry is
# either a single string or an iterable of strings — TODO confirm against the
# pickled data.
caption_keys = list(captions.keys()) if isinstance(captions, dict) else range(len(captions))

train_data = []
for key in caption_keys:
    caption_entry = captions[key]
    # One image may have several reference captions; emit one pair per caption.
    caption_texts = [caption_entry] if isinstance(caption_entry, str) else list(caption_entry)
    train_data.extend((image_features[key], text) for text in caption_texts)

In [26]:

# Load the pretrained GPT-2 tokenizer and language-model head.
tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
model = transformers.TFGPT2LMHeadModel.from_pretrained('gpt2')

# The LM head returns unnormalised logits, hence from_logits=True.
cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)

NUM_EPOCHS = 10
LOG_EVERY = 100  # print the loss every LOG_EVERY steps

# Materialise the pairs: a generator would be exhausted after the first epoch
# and does not support len().
train_pairs = list(train_data)

# Fine-tune GPT-2 on the captions with next-token prediction (batch size 1).
# NOTE(review): the image feature is NOT fed to the model, so generated captions
# are not conditioned on the image. Conditioning needs e.g. a learned projection
# of the feature into GPT-2's embedding space, prepended via `inputs_embeds` — TODO.
for epoch in range(NUM_EPOCHS):
    for i, (_image_feature, caption) in enumerate(train_pairs):
        # Assumes `caption` is a single string — TODO confirm.
        # Append EOS so the model learns where a caption ends.
        token_list = tokenizer.encode(caption) + [tokenizer.eos_token_id]
        if len(token_list) < 2:
            continue  # need at least one (input, next-token) pair to score
        token_ids = tf.constant([token_list])  # shape (1, seq_len)

        with tf.GradientTape() as tape:
            # training=True enables dropout during fine-tuning.
            logits = model(token_ids, training=True).logits
            # Position t predicts token t+1: drop the last logit and the first label.
            loss = cross_entropy(token_ids[:, 1:], logits[:, :-1, :])

        # Backpropagate the loss and update the weights.
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        if i % LOG_EVERY == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                epoch + 1, NUM_EPOCHS, i + 1, len(train_pairs), loss.numpy()))

All model checkpoint layers were used when initializing TFGPT2LMHeadModel.

All the layers of TFGPT2LMHeadModel were initialized from the model checkpoint at gpt2.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFGPT2LMHeadModel for predictions without further training.


In [30]:
# Persist the fine-tuned GPT-2 weights to an HDF5 weights file.
# NOTE(review): `model.save_pretrained(<dir>)` would also store the config so the
# model could be reloaded with `TFGPT2LMHeadModel.from_pretrained(<dir>)`.
weights_path = "gpt2_model.h5"
model.save_weights(weights_path)