# Generating Image Captions

This project combines a CNN with an LSTM to generate image captions.
A pretrained CNN is fine-tuned to extract features from the given image, and these features are passed as input to the LSTM, which generates the captions.


ResNet18 (from `torchvision.models`, pretrained on ImageNet) was used as the pretrained CNN.
The Flickr30k dataset was used to train the combined models; it can be found here: ... (link to be added)

## Importing the libraries

In [None]:
%pip install torch torchvision pandas nltk

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import random

## Choosing the device

In [None]:
# Use Apple's Metal (MPS) backend when it is available, otherwise fall back to the CPU,
# and report which one was picked.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device.type)

## Loading the captions from the dataset

Loading the dataset:

In [None]:
import pandas as pd
import os
from PIL import Image

import re
from collections import Counter

import nltk
nltk.download('punkt_tab')
from nltk.tokenize import word_tokenize

# Paths to the Flickr30k images and the caption file
image_folder = './flickr30k_images/flickr30k_images'
captions_file = './flickr30k_images/results.csv'

# Load captions.
# results.csv is '|'-delimited ("image_name| comment_number| comment"), so reading it
# with sep='\t' puts each whole line into the 'image' column; the lines are split on
# '|' later by split_string. No '#'-splitting is done here: that belongs to the
# Flickr8k "img.jpg#n" format, and on these lines it would silently cut off any
# caption containing a '#'.
captions = pd.read_csv(captions_file, sep='\t', header=None, names=['image', 'caption'])


def split_string(input_str):
    """Parse one caption line of the form "image_name| comment_number| comment".

    Args:
        input_str: a raw line from the caption file.

    Returns:
        [image_name, comment_number, caption] with surrounding whitespace stripped
        and comment_number converted to int.

    Raises:
        ValueError: if the comment number is not an integer.
        IndexError: if the line contains no '|' at all.
    """
    # maxsplit=2 keeps a caption that itself contains '|' in one piece (otherwise the
    # row would have 4+ fields and break the 4-way unpacking done later).
    parts = [part.strip() for part in input_str.split('|', 2)]
    if len(parts) == 2:
        # Recover a row whose second '|' is missing, e.g. "img.jpg| 4   A dog runs ."
        # (NOTE(review): at least one such row is reported in the Kaggle Flickr30k
        # results.csv — confirm against your copy).
        number, _, caption = parts[1].partition(' ')
        parts = [parts[0], number, caption.strip()]
    parts[1] = int(parts[1])
    return parts

# Skip the header row ("image_name| comment_number| comment") and parse every
# remaining line into [image_name, comment_number, caption].
captions = captions['image'].iloc[1:].apply(split_string)


# Clean and tokenize captions
def preprocess_caption(element):
    caption = element[2]
    
    caption = caption.lower()
    caption = re.sub(r'[^\w\s]', '', caption)  # Remove punctuation

    element.append(word_tokenize(caption) + ['<EOS>'])  # emp is used to pad the caption to a fixed length with empty space

    return element

# Tokenize every caption: each row becomes [image_name, number, caption, tokens].
captions_tokenized = captions.apply(preprocess_caption)

# Group the tokenized captions by image id (the file name without '.jpg'):
# vocab_dict[image_id] -> list of that image's token lists.
vocab_dict = {}
for image_name, _, _, caption_tokens in captions_tokenized:
    image_name = image_name.replace('.jpg', '')
    vocab_dict.setdefault(image_name, []).append(caption_tokens)


Checking that the captions were loaded correctly:

In [None]:
# Spot-check two image ids: all captions of the first one; the first two captions
# and then the full list of the second one.
if '1000268201' not in vocab_dict:
    print('Not found')
else:
    print(vocab_dict['1000268201'])

if '4153903524' not in vocab_dict:
    print('Not found')
else:
    print(vocab_dict['4153903524'][0])
    print(vocab_dict['4153903524'][1])
    print(vocab_dict['4153903524'])

Creating the vocabulary:

In [None]:
# Build the vocabulary from every token of every tokenized caption (field 3 of each row).
all_words = [word for item in captions_tokenized for word in item[3]]

vocab_counter = Counter(all_words)
vocab_size = len(vocab_counter)

# Tokens sorted by descending frequency.
most_common_tokens = vocab_counter.most_common()

# token -> index: the most frequent token gets 0, the next one 1, and so on.
vocab = dict(zip(
    (token for token, _count in most_common_tokens),
    range(len(most_common_tokens)),
))

def sentence_to_tensor(sentence):
    """Convert a list of tokens into a 1-D tensor of their indices in ``vocab``."""
    token_ids = [vocab[tok] for tok in sentence]
    return torch.tensor(token_ids)

# Caption-length statistics over ALL tokenized captions (lengths include '<EOS>').
# Iterating the Series directly replaces the hard-coded range(158915), which
# skipped the last caption (the Series index starts at 1 after the header was dropped)
# and silently assumed the exact dataset size.
max_len = 0
longest = None
total = 0
longer_than_20 = 0  # 20 matches MAX_OUTPUT_LEN used by the decoder further below
for item in captions_tokenized:
    n_tokens = len(item[3])
    total += n_tokens
    if n_tokens > max_len:
        longest = item
        max_len = n_tokens
    if n_tokens > 20:
        longer_than_20 += 1

n_captions = len(captions_tokenized)
print('max_len: ', max_len)
print('longest: ', longest[3] if longest is not None else None)
print('average: ', total / n_captions if n_captions else 0.0)
print('longer_than_20: ', longer_than_20)


Function to get list of captions:

In [None]:
def get_captions(image_name):
    """Return the list of tokenized captions for an image id, or None if unknown."""
    return vocab_dict.get(image_name)


print(get_captions(image_name='4153903524'))

## Loading the images from the dataset

In [None]:
import os
from PIL import Image

BATCH_SIZE = 32


class CustomImageDataset(Dataset):
    """Dataset over the '.jpg' files of a folder, labelled by file name.

    Each item is ``(image, label)`` where ``label`` is the file name without its
    extension (e.g. '1000268201'), i.e. the key format used by ``vocab_dict``.

    Args:
        image_folder: directory containing the images.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, image_folder, transform=None):
        self.image_folder = image_folder
        # Sorted so indices are reproducible across runs/machines
        # (os.listdir returns entries in arbitrary order).
        self.image_files = sorted(f for f in os.listdir(image_folder) if f.endswith('.jpg'))
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_folder, img_name)
        # Convert to RGB so grayscale/CMYK/RGBA files still yield 3 channels (the
        # 3-channel Normalize transform needs that), and use `with` so the file
        # handle is closed once the pixels are loaded.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # Label = file name without its extension
        label = os.path.splitext(img_name)[0]

        if self.transform:
            image = self.transform(image)

        return image, label

# Build the image dataset and the shuffled dataloader used for training.
image_folder = './flickr30k_images/flickr30k_images'
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # ImageNet channel mean/std, the normalization used by torchvision's pretrained models
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = CustomImageDataset(image_folder=image_folder, transform=transform)
# Use the BATCH_SIZE constant instead of repeating the literal, so changing the
# batch size in one place keeps the loader and the rest of the notebook in sync.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

## CNN model (from ResNet18)

In [None]:
# Load ResNet18 with ImageNet weights. `weights=` replaces the deprecated
# `pretrained=True` flag; IMAGENET1K_V1 is the checkpoint `pretrained=True` loaded.
cnn_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

CNN_OUTPUT_SIZE = 256

# Freeze the pretrained backbone: only the new head created below is trained.
for param in cnn_model.parameters():
    param.requires_grad = False

# Replace the classifier with a new, untrained projection to CNN_OUTPUT_SIZE features.
num_features = cnn_model.fc.in_features  # 512 for ResNet18
cnn_model.fc = nn.Linear(num_features, CNN_OUTPUT_SIZE)

# A freshly created nn.Linear is already trainable; kept explicit for clarity.
for param in cnn_model.fc.parameters():
    param.requires_grad = True

cnn_model = cnn_model.to(device)


## LSTM model implementation

Helper function for the LSTM:

In [None]:
def output_to_token(output):
    """Return, for each row, the index of the largest value along dim 1."""
    return torch.max(output, dim=1).indices

In [None]:
class LSTM(nn.Module):
    """Hand-written LSTM decoder unrolled for a fixed number of steps.

    Step 0 receives the image feature vector as its input. Every later step
    receives the embedding of a token index: with probability
    ``teacher_forcing_ratio`` the argmax of the previous ground-truth row,
    otherwise the argmax of the model's own previous output.

    Args:
        input_size: width of each step's input (the image features and the
            token embeddings must both have this width).
        memory_size: width of the hidden state and of the cell state.
        output_size: number of scores produced per step (the vocabulary size).
        max_output_len: number of decoding steps.
    """

    def __init__(self, input_size, memory_size, output_size, max_output_len):
        super(LSTM, self).__init__()

        self.max_output_len = max_output_len
        self.input_size = input_size
        self.memory_size = memory_size
        self.output_size = output_size

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

        # Embedding of fed-back token indices. Sized by output_size (the number of
        # classes the model predicts) instead of the notebook-global vocab_size, so
        # the class no longer depends on globals. Both are equal for the model built
        # in this notebook, so existing checkpoints still load.
        self.embedding = nn.Embedding(output_size, input_size)

        # Gate projections; each gate is act(W_x x + W_h h_prev).
        self.forget_x = nn.Linear(input_size, memory_size)      # forget gate f: share of old memory kept
        self.forget_h = nn.Linear(memory_size, memory_size)

        self.input_x = nn.Linear(input_size, memory_size)       # input gate i: share of new memory written
        self.input_h = nn.Linear(memory_size, memory_size)

        self.output_x = nn.Linear(input_size, memory_size)      # output gate o
        self.output_h = nn.Linear(memory_size, memory_size)

        self.new_memory_x = nn.Linear(input_size, memory_size)  # candidate memory g
        self.new_memory_h = nn.Linear(memory_size, memory_size)

        self.output_layer = nn.Linear(memory_size, output_size)  # hidden state -> per-step scores

        # Learnable initial hidden and cell states, repeated for every batch item
        self.h_0 = nn.Parameter(torch.zeros(1, memory_size))
        self.c_0 = nn.Parameter(torch.zeros(1, memory_size))

    def forward_step(self, input_tensor, prev_hidden, prev_state):
        """One LSTM cell update (torch.nn.LSTM formulas); returns (new_hidden, new_cell)."""
        i = self.sigmoid(self.input_x(input_tensor) + self.input_h(prev_hidden))
        f = self.sigmoid(self.forget_x(input_tensor) + self.forget_h(prev_hidden))
        o = self.sigmoid(self.output_x(input_tensor) + self.output_h(prev_hidden))
        g = self.tanh(self.new_memory_x(input_tensor) + self.new_memory_h(prev_hidden))

        c = f * prev_state + i * g    # cell state (long-term memory)
        new_hidden = o * self.tanh(c) # hidden state (short-term memory)
        return new_hidden, c

    def forward(self, initial_input, ground_truth=None, teacher_forcing_ratio=0.5):
        """Unroll the decoder for ``max_output_len`` steps.

        Args:
            initial_input: step-0 input, shape (batch, input_size).
            ground_truth: optional tensor where ``ground_truth[:, step]`` is a
                (batch, output_size) row of scores/one-hot targets; the argmax of the
                previous step's row is fed back when teacher forcing is chosen.
            teacher_forcing_ratio: probability, drawn once per step for the whole
                batch, of feeding the ground truth instead of the model's own
                previous prediction.

        Returns:
            Raw scores of shape (batch, max_output_len, output_size).
        """
        batch_size = initial_input.size(0)
        # Allocate on the input's device instead of a notebook-global `device`,
        # so the module works wherever it (and its input) has been moved.
        run_device = initial_input.device
        hidden = self.h_0.repeat(batch_size, 1).to(run_device)
        state = self.c_0.repeat(batch_size, 1).to(run_device)

        # Per-step outputs, filled in below: (steps, batch, output_size)
        outputs = torch.zeros(self.max_output_len, batch_size, self.output_size, device=run_device)

        # First step: the image features are the input.
        hidden, state = self.forward_step(initial_input, hidden, state)
        outputs[0] = self.output_layer(hidden)

        for i in range(1, self.max_output_len):
            if ground_truth is not None and random.random() < teacher_forcing_ratio:
                # Teacher forcing: feed the previous ground-truth token.
                prev_output = ground_truth[:, i-1]
            else:
                # Feed back the model's own previous prediction.
                prev_output = outputs[i-1].clone()

            _, indices = prev_output.max(dim=1)
            input_tensor = self.embedding(indices)

            hidden, state = self.forward_step(input_tensor, hidden, state)
            outputs[i] = self.output_layer(hidden)

        return outputs.transpose(0, 1)

# Decoder instance: takes CNN_OUTPUT_SIZE image features, keeps a 512-wide
# hidden/cell state and scores vocab_size tokens at each of MAX_OUTPUT_LEN steps.
MAX_OUTPUT_LEN = 20
lstm_model = LSTM(
    input_size=CNN_OUTPUT_SIZE,
    memory_size=512,
    output_size=vocab_size,
    max_output_len=MAX_OUTPUT_LEN,
).to(device)


## Combining the CNN and LSTM

In [None]:
class CaptionModel(nn.Module):
    """Image-captioning model: a CNN encoder followed by a sequence decoder.

    Args:
        cnn: image encoder mapping an image batch to feature vectors. Defaults to
            the notebook-global ``cnn_model`` (the original behavior).
        lstm: decoder called as ``lstm(features, ground_truth, teacher_forcing_ratio)``.
            Defaults to the notebook-global ``lstm_model``.
    """

    def __init__(self, cnn=None, lstm=None):
        super(CaptionModel, self).__init__()
        # Globals are only looked up when no module is passed in, so the class can
        # also be built from explicitly supplied encoder/decoder modules.
        self.cnn_model = cnn_model if cnn is None else cnn
        self.lstm_model = lstm_model if lstm is None else lstm

    def forward(self, image_tensor, ground_truth=None, teacher_forcing_ratio=0.5):
        """Encode ``image_tensor`` and decode it into per-step token scores.

        ``ground_truth`` and ``teacher_forcing_ratio`` are forwarded unchanged to the
        decoder; the decoder's output is returned as-is.
        """
        image_features = self.cnn_model(image_tensor)
        return self.lstm_model(image_features, ground_truth, teacher_forcing_ratio)

# Wrap the CNN encoder and LSTM decoder into one module and move it to the device.
caption_model = CaptionModel()
caption_model = caption_model.to(device)

## Training Preparation

Some helper functions:

In [None]:
def indices_to_tokens(output, vocabulary=None):
    """Decode per-step score vectors into token strings by taking the argmax.

    Args:
        output: iterable over batch items, each an iterable of 1-D score tensors
            over the vocabulary (e.g. the (batch, steps, vocab) tensor returned by
            CaptionModel).
        vocabulary: token -> index mapping; defaults to the notebook-global ``vocab``.

    Returns:
        One list of token strings per batch item.
    """
    if vocabulary is None:
        vocabulary = vocab
    # Build the index -> token lookup once. The previous list(vocab.keys())[i] rebuilt
    # an O(V) list for every token and relied on dict insertion order matching indices.
    index_to_token = {idx: token for token, idx in vocabulary.items()}

    sentence = []
    for batch in output:
        batch_sentence = []
        for tokens in batch:
            _, token_index = torch.max(tokens, dim=0)
            batch_sentence.append(index_to_token[int(token_index)])
        sentence.append(batch_sentence)
    return sentence

def tokens_to_indices(sentence):
    """Map tokens to vocabulary indices, forced to exactly MAX_OUTPUT_LEN entries.

    Longer sequences are truncated; shorter ones are right-padded with the
    index of '<EOS>'. Returns a 1-D tensor.
    """
    ids = [vocab[tok] for tok in sentence][:MAX_OUTPUT_LEN]
    missing = MAX_OUTPUT_LEN - len(ids)
    if missing > 0:
        ids.extend([vocab['<EOS>']] * missing)
    return torch.tensor(ids)

def indexes_to_onehot(indexes):
    """One-hot encode a (batch_size, seq_len) index tensor over the vocabulary.

    Returns a float tensor of shape (batch_size, seq_len, len(vocab)) on ``device``.
    """
    n_batch, n_steps = indexes.size(0), indexes.size(1)
    encoded = torch.zeros(n_batch, n_steps, len(vocab)).to(device)
    return encoded.scatter_(2, indexes.unsqueeze(2), 1)

def calculate_teacher_forcing_ratio(epoch, max_epochs):
    """Teacher-forcing probability decaying linearly from 0.9 (epoch 0) to 0.1 (epoch == max_epochs)."""
    remaining = 1 - epoch / max_epochs
    return 0.1 + 0.8 * remaining

Loss function (criterion), optimizer and learning-rate scheduler:

In [None]:
import torch.optim.lr_scheduler as lr_scheduler

# Per-class weights for CrossEntropyLoss, looked up by token instead of assuming which
# token sits at index 0/1. NOTE: these are currently NOT passed to the criterion, and a
# weight of 30 *up-weights* a class relative to the default 1 (it does not reduce it).
loss_weights = torch.ones(len(vocab)).to(device)
loss_weights[vocab['a']] = 30
loss_weights[vocab['<EOS>']] = 30

# Plain cross-entropy over the per-step vocabulary scores.
# Variants tried before: ignore_index=vocab['<EOS>'], weight=loss_weights.
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters; the frozen ResNet backbone has requires_grad=False,
# so only the new fc head and the LSTM are updated. (An SGD optimizer used to be
# created here and was immediately overwritten by this one.)
optimizer = optim.Adam(caption_model.parameters(), lr=0.001)

# Multiply the learning rate by 0.966 every 900 scheduler steps; the training loop calls
# scheduler.step() once per optimizer update.
scheduler = lr_scheduler.StepLR(optimizer, step_size=900, gamma=0.966)

## Some tests before training

In [None]:
# Pull one batch from the dataloader to inspect shapes and labels.
data_iter = iter(dataloader)
first_batch = next(data_iter)

inputs = first_batch[0]  # image tensor batch
labels = first_batch[1]  # image ids (file names without extension)

print("First batch inputs shape:", inputs.shape)
print("First batch labels:", labels)

In [None]:
# Look up all captions of every image in the batch (by image id).
captions = [get_captions(str(label)) for label in labels]
print(captions)
print(len(captions), len(captions[0]))

# For each image, pick its shortest caption as the training target.
shortest_captions = [min(caption_list, key=len) for caption_list in captions]
shortest_indices = [
    caption_list.index(shortest)
    for caption_list, shortest in zip(captions, shortest_captions)
]
print('shortest_indices: ', shortest_indices)

# Fixed-length index tensors, stacked to (batch, MAX_OUTPUT_LEN).
target_output = torch.stack([
    tokens_to_indices(caption_list[chosen])
    for caption_list, chosen in zip(captions, shortest_indices)
]).to(device)
print('target_output.shape: ', target_output.shape)

onehot_target = indexes_to_onehot(target_output)
print('onehot_target.shape: ', onehot_target.shape)


In [None]:
# Sanity-check one forward pass and the loss on the first batch. No backward pass is
# run here, so skip building the autograd graph.
inputs = inputs.to(device)
with torch.no_grad():
    outputs = caption_model(inputs, onehot_target, 0.5)

print('outputs shape', outputs.shape)
print('target output shape: ', target_output.shape)

# (batch, steps, vocab) -> (batch*steps, vocab). Using -1 instead of BATCH_SIZE keeps
# this valid even when the batch is smaller than BATCH_SIZE.
outputs_flat = outputs.reshape(-1, vocab_size)
target_output_flat = target_output.reshape(-1)

print("shapes: ", outputs_flat.shape, target_output_flat.shape)

loss = criterion(outputs_flat, target_output_flat)
print(loss.item())


## Training the Model

In [None]:
epochs = 10
accumulation_steps = 10  # batches per optimizer update (gradient accumulation)
total_loss = 0

checkpoint_path = 'checkpoint_with_embeddings.pth'

# Resume from an existing checkpoint once, before the loop. Previously this happened
# inside the loop by reassigning `epoch`, which only lasted for a single iteration,
# so training always restarted from epoch 0.
start_epoch = 0
if os.path.exists(checkpoint_path):
    print('loading the states from the checkpoint')
    checkpoint = torch.load(checkpoint_path, map_location=device)
    caption_model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    scheduler.load_state_dict(checkpoint['scheduler_state'])
    # Checkpoints are written after an epoch finishes, so continue with the next one.
    start_epoch = checkpoint['epoch'] + 1
    if start_epoch >= epochs:
        print(f'checkpoint already covers all {epochs} epochs; increase `epochs` to train further')

# The testing cell switches the model to eval mode; make sure we train in train mode.
caption_model.train()

for epoch in range(start_epoch, epochs):
    # Start every epoch clean; this also drops gradients/loss of an incomplete
    # accumulation window left over from the previous epoch's last batches.
    optimizer.zero_grad()
    total_loss = 0.0
    for i, (inputs, labels) in enumerate(dataloader):
        # Target = the shortest of each image's captions, as fixed-length indices.
        captions = [get_captions(f'{label}') for label in labels]
        target_output = torch.stack(
            [tokens_to_indices(min(caption_list, key=len)) for caption_list in captions]
        ).to(device)

        # forward pass
        inputs = inputs.to(device)
        outputs = caption_model(
            inputs,                                         # image batch
            indexes_to_onehot(target_output),               # one-hot targets (for teacher forcing)
            calculate_teacher_forcing_ratio(epoch, epochs)  # decays as training progresses
        )

        # (batch, steps, vocab) -> (batch*steps, vocab). Using -1 instead of BATCH_SIZE
        # keeps this valid for the smaller final batch of an epoch.
        outputs_flat = outputs.reshape(-1, vocab_size)
        target_output_flat = target_output.reshape(-1)

        # Scale so the accumulated gradient is the mean over the accumulation window.
        loss = criterion(outputs_flat, target_output_flat) / accumulation_steps
        loss.backward()
        total_loss += loss.item()

        if (i + 1) % accumulation_steps == 0:
            # update weights
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            print(f'Epoch {epoch + 1}, batch {i + 1}, loss: {total_loss}')
            total_loss = 0.0

    # Checkpoint at the end of every epoch. (Previously it was only written when batch
    # index 990 was reached — never, for a loader with fewer batches — and the remaining
    # batches of the epoch were skipped.)
    checkpoint = {
        "epoch": epoch,
        "model_state": caption_model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict()
    }
    torch.save(checkpoint, checkpoint_path)


Testing the model:

In [None]:
import random
import torch
from PIL import Image
from torchvision import transforms

# Must match the file the training cell writes ('embeddngs' typo fixed).
checkpoint_path = 'checkpoint_with_embeddings.pth'

# Set the device with the same rule as at the top of the notebook (the previous check
# tested CUDA availability but then selected MPS).
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Define the image folder and transform (same preprocessing as for training)
image_folder = './flickr30k_images/flickr30k_images'
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the trained weights; map_location lets a checkpoint saved on another device load here.
checkpoint = torch.load(checkpoint_path, map_location=device)
caption_model.load_state_dict(checkpoint['model_state'])
caption_model.to(device)
caption_model.eval()

# Pick a random image from the dataset; convert to RGB so it always has 3 channels.
random_image = random.choice(dataset.image_files)
image_path = os.path.join(image_folder, random_image)
with Image.open(image_path) as img:
    image = img.convert('RGB')
print(image_path)

# Preprocess into a batch of one
image_tensor = transform(image).unsqueeze(0).to(device)

# Run the model (no teacher forcing: no ground truth is given)
with torch.no_grad():
    cnn_outputs = caption_model.cnn_model(image_tensor)
    outputs = caption_model(image_tensor)

print(cnn_outputs[0][0:20])

# Decode the scores, and cut the caption at the first '<EOS>' (targets are padded
# with '<EOS>', so everything after it is padding).
output_tokens = indices_to_tokens(outputs)[0]
if '<EOS>' in output_tokens:
    output_tokens = output_tokens[:output_tokens.index('<EOS>')]
output_sentence = ' '.join(output_tokens)

print(output_sentence)
