<a href="https://colab.research.google.com/github/Gopib03/Pytorch_Models/blob/main/image_caption_generator_using_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Image caption generator using PyTorch

Here, we use the Common Objects in Context (COCO) dataset.
This dataset consists of over 200,000 labeled images, with five captions for each image.


We are using a slightly older version of the dataset as it is slightly smaller in size, enabling us to get the results faster.


The training and validation datasets are 13 GB and 6 GB in size, respectively. Downloading and extracting the dataset files, as well as cleaning and processing them, might take a while.

Downloading the image captioning datasets

In [None]:
# download images and annotations to the data directory
!wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip -P ./data_dir/ -P ./data_dir/
!wget http://images.cocodataset.org/zips/train2014.zip -P ./data_dir/
!wget http://images.cocodataset.org/zips/val2014.zip -P ./data_dir/
# extract zipped images and annotations and remove the zip files
!unzip ./data_dir/annotations_trainval2014.zip -d ./data_dir/
!rm ./data_dir/annotations_trainval2014.zip
!unzip ./data_dir/train2014.zip -d ./data_dir/
!rm ./data_dir/train2014.zip
!unzip ./data_dir/val2014.zip -d ./data_dir/
!rm ./data_dir/val2014.zip


--2026-01-02 21:53:14--  http://images.cocodataset.org/annotations/annotations_trainval2014.zip
Resolving images.cocodataset.org (images.cocodataset.org)... 3.5.27.159, 3.5.28.79, 3.5.28.218, ...
Connecting to images.cocodataset.org (images.cocodataset.org)|3.5.27.159|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 252872794 (241M) [application/zip]
Saving to: ‘./data_dir/annotations_trainval2014.zip’


2026-01-02 21:53:20 (38.9 MB/s) - ‘./data_dir/annotations_trainval2014.zip’ saved [252872794/252872794]

--2026-01-02 21:53:21--  http://images.cocodataset.org/zips/train2014.zip
Resolving images.cocodataset.org (images.cocodataset.org)... 3.5.28.132, 3.5.27.126, 3.5.31.111, ...
Connecting to images.cocodataset.org (images.cocodataset.org)|3.5.28.132|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13510573713 (13G) [application/zip]
Saving to: ‘./data_dir/train2014.zip’


Next, we import a few dependencies, including some of the crucial modules used in this notebook:

In [None]:
import nltk
from pycocotools.coco import COCO
import torch.utils.data as data
import torchvision.models as models
import torchvision.transforms as transforms
from torch.nn.utils.rnn import pack_padded_sequence
from collections import Counter
import os
from PIL import Image
import torch
import torch.nn as nn

Besides importing the nltk library, we will also need to download its punkt tokenizer model

In [None]:
# Fetch the NLTK tokenizer data needed by nltk.tokenize.word_tokenize,
# which build_vocabulary calls on every caption.
nltk.download('punkt')
# NOTE(review): 'punkt_tab' is presumably the resource name newer NLTK
# releases look for instead of 'punkt' — TODO confirm against the installed version.
nltk.download('punkt_tab')

In [None]:
def build_vocabulary(json, threshold):
    """Build a vocabulary wrapper from a COCO captions annotation file.

    Every caption is lower-cased and tokenized with NLTK; tokens that occur
    at least ``threshold`` times are added to the vocabulary after the four
    special tokens ``<pad>``, ``<start>``, ``<end>`` and ``<unk>``.

    Args:
        json: Path to the COCO captions JSON file (handed to ``COCO``).
        threshold: Minimum number of occurrences a token needs to be kept.

    Returns:
        A ``Vocab`` instance. ``Vocab`` is defined in a later cell of this
        notebook and must exist by the time this function is called.
    """
    coco = COCO(json)
    counter = Counter()
    num_captions = len(coco.anns)
    for i, annotation in enumerate(coco.anns.values()):
        caption = str(annotation['caption'])
        counter.update(nltk.tokenize.word_tokenize(caption.lower()))
        if (i + 1) % 1000 == 0:
            print("[{}/{}]Tokenized the captions.".format(i + 1, num_captions))

    # Tokens seen fewer than `threshold` times are discarded.
    frequent_tokens = [token for token, cnt in counter.items()
                       if cnt >= threshold]

    # Special tokens go first so that they receive ids 0-3.
    vocab = Vocab()
    for special in ('<pad>', '<start>', '<end>', '<unk>'):
        vocab.add_token(special)
    for token in frequent_tokens:
        vocab.add_token(token)
    return vocab

In [None]:
class Vocab:
    """Two-way mapping between tokens and consecutive integer ids."""

    def __init__(self):
        # token -> id and id -> token tables, plus the next id to hand out.
        self.word_to_id = {}
        self.id_to_word = {}
        self.idx = 0

    def add_token(self, word):
        """Give ``word`` the next free id; do nothing if it is already known."""
        if word in self.word_to_id:
            return
        new_id = self.idx
        self.word_to_id[word] = new_id
        self.id_to_word[new_id] = word
        self.idx = new_id + 1

    def __call__(self, word):
        """Return the id of ``word``, or the id of ``'<unk>'`` if unknown.

        Raises:
            KeyError: if ``word`` is unknown and ``'<unk>'`` was never added.
        """
        try:
            return self.word_to_id[word]
        except KeyError:
            return self.word_to_id['<unk>']

    def __len__(self):
        """Return the number of tokens added so far."""
        return self.idx

def build_vocabulary(json, threshold):
    """Build a ``Vocab`` from the captions of a COCO annotation file.

    Args:
        json: Path to the COCO captions JSON file (handed to ``COCO``).
        threshold: Minimum number of occurrences a token needs in order to
            be included in the vocabulary.

    Returns:
        A ``Vocab`` whose first four entries are the special tokens
        ``<pad>``, ``<start>``, ``<end>`` and ``<unk>``, followed by every
        token that occurs at least ``threshold`` times.
    """
    coco = COCO(json)
    counter = Counter()
    total = len(coco.anns)
    # Iterate over the annotation records directly rather than looking each
    # one up by key; this also avoids shadowing the builtin `id`.
    for i, annotation in enumerate(coco.anns.values()):
        caption = str(annotation['caption'])
        tokens = nltk.tokenize.word_tokenize(caption.lower())
        counter.update(tokens)
        if (i + 1) % 1000 == 0:
            print("[{}/{}]Tokenized the captions.".format(i + 1, total))

    # If word freq < 'threshold', then the word is discarded.
    kept_tokens = [token for token, cnt in counter.items()
                   if cnt >= threshold]

    # Create vocab wrapper + add special tokens first (ids 0-3).
    vocab = Vocab()
    vocab.add_token('<pad>')
    vocab.add_token('<start>')
    vocab.add_token('<end>')
    vocab.add_token('<unk>')
    # Add the remaining words to the vocab.
    for token in kept_tokens:
        vocab.add_token(token)
    return vocab

In [None]:
import pickle
# Build the vocabulary from the training captions (tokens seen fewer than
# 4 times are dropped) and pickle it so later cells can reload it.
vocab_path = './data_dir/vocabulary.pkl'
vocab = build_vocabulary(
    json='data_dir/annotations/captions_train2014.json',
    threshold=4,
)
with open(vocab_path, 'wb') as f:
    pickle.dump(vocab, f)

# Report the vocabulary size and where it was written.
print("Total vocabulary size: {}".format(len(vocab)))
print("Saved the vocabulary wrapper to '{}'".format(vocab_path))

In [7]:
def reshape_images(image_path, output_path, shape):
    """Resize every file in ``image_path`` and save it under ``output_path``.

    Args:
        image_path: Directory whose entries are each opened as an image.
        output_path: Directory that receives the resized images under their
            original file names. It is created if it does not exist.
        shape: Target size, forwarded unchanged to ``reshape_image``.
    """
    os.makedirs(output_path, exist_ok=True)
    images = os.listdir(image_path)
    num_im = len(images)
    for i, im in enumerate(images):
        # Plain read-only mode: the source file is never written to, so
        # read-only datasets work too.
        with open(os.path.join(image_path, im), 'rb') as f:
            with Image.open(f) as image:
                # Capture the source format before resizing; reading it from
                # the resized image would no longer give the format of the
                # file that was opened.
                image_format = image.format
                resized = reshape_image(image, shape)
                resized.save(os.path.join(output_path, im), image_format)
        if (i + 1) % 100 == 0:
            print("[{}/{}] Resized the images and saved into '{}'."
                  .format(i + 1, num_im, output_path))

# Inputs for the resize step below: source directory, destination
# directory, and target size (adjust as needed).
image_path = './data_dir/val2014'
output_path = './data_dir/resized_images_val2014'
image_shape = [256, 256]

# Make sure the destination exists; exist_ok keeps the cell re-runnable and
# avoids a separate check-then-create step.
os.makedirs(output_path, exist_ok=True)

# Helper called by reshape_images for each opened image. The name is looked
# up when reshape_images runs, so defining it here, before the call below,
# is sufficient.
def reshape_image(image, shape):
    """Resize ``image`` to ``shape`` with the Lanczos filter and return the result."""
    resized = image.resize(shape, Image.LANCZOS)
    return resized

# Resize every file found in image_path to image_shape and write the
# results into output_path (progress is printed every 100 images).
reshape_images(image_path, output_path, image_shape)

[100/40504] Resized the images and saved into './data_dir/resized_images_val2014'.
[200/40504] Resized the images and saved into './data_dir/resized_images_val2014'.
[300/40504] Resized the images and saved into './data_dir/resized_images_val2014'.
[400/40504] Resized the images and saved into './data_dir/resized_images_val2014'.
[500/40504] Resized the images and saved into './data_dir/resized_images_val2014'.
[600/40504] Resized the images and saved into './data_dir/resized_images_val2014'.
[700/40504] Resized the images and saved into './data_dir/resized_images_val2014'.
[800/40504] Resized the images and saved into './data_dir/resized_images_val2014'.
[900/40504] Resized the images and saved into './data_dir/resized_images_val2014'.
[1000/40504] Resized the images and saved into './data_dir/resized_images_val2014'.
[1100/40504] Resized the images and saved into './data_dir/resized_images_val2014'.
[1200/40504] Resized the images and saved into './data_dir/resized_images_val2014'.
[

In [8]:
def get_loader(data_path, coco_json_path, vocabulary, transform, batch_size, shuffle):
    """Return a ``DataLoader`` over a ``CustomCocoDataset``.

    ``data_path``, ``coco_json_path``, ``vocabulary`` and ``transform`` are
    passed to ``CustomCocoDataset`` by keyword; ``batch_size`` and
    ``shuffle`` are passed to the ``DataLoader``. Batches are assembled by
    ``collate_function``.

    NOTE(review): ``CustomCocoDataset`` and ``collate_function`` must be
    defined elsewhere before this function is called.
    """
    caption_dataset = CustomCocoDataset(
        data_path=data_path,
        coco_json_path=coco_json_path,
        vocabulary=vocabulary,
        transform=transform,
    )
    return torch.utils.data.DataLoader(
        dataset=caption_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_function,
    )


In [9]:
# This cell is now redundant as collate_function is defined in cell 'dJKSjQ6Oj2Ah'.

In [10]:
def get_loader(data_path, coco_json_path, vocabulary, transform, batch_size=128, shuffle=True, num_workers=2):
    """Return a ``DataLoader`` over a ``CustomCocoDataset``.

    Args:
        data_path: Forwarded to ``CustomCocoDataset`` as ``data_path``.
        coco_json_path: Forwarded to ``CustomCocoDataset`` as ``coco_json_path``.
        vocabulary: Forwarded to ``CustomCocoDataset`` as ``vocabulary``.
        transform: Forwarded to ``CustomCocoDataset`` as ``transform``.
        batch_size: Batch size given to the ``DataLoader`` (default 128).
        shuffle: Shuffle flag given to the ``DataLoader`` (default True).
        num_workers: Worker count given to the ``DataLoader`` (default 2).

    NOTE(review): ``CustomCocoDataset`` and ``collate_function`` must be
    defined elsewhere before this function is called.
    """
    dataset = CustomCocoDataset(
        data_path=data_path,
        coco_json_path=coco_json_path,
        vocabulary=vocabulary,
        transform=transform,
    )
    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'collate_fn': collate_function,
    }
    return torch.utils.data.DataLoader(dataset=dataset, **loader_options)

The CNN_LSTM model

In [11]:
class CNNModel(nn.Module):
    """Image encoder: pretrained ResNet-152 backbone plus a trainable projection.

    The backbone runs without gradient tracking; only ``linear_layer`` and
    ``batch_norm`` are meant to be trained.
    """

    def __init__(self, embedding_size):
        """Load pretrained ResNet-152 & replace last fully connected layer.

        Args:
            embedding_size: Size of the feature vector produced per image.
        """
        super().__init__()
        resnet = models.resnet152(pretrained=True)
        # Keep every child module except the last one (the fc layer).
        module_list = list(resnet.children())[:-1]
        self.resnet_module = nn.Sequential(*module_list)
        # New head: project the backbone features to embedding_size.
        self.linear_layer = nn.Linear(
            resnet.fc.in_features, embedding_size)
        self.batch_norm = nn.BatchNorm1d(
            embedding_size, momentum=0.01)

    def forward(self, input_images):
        """Extract feature vectors from a batch of images.

        Args:
            input_images: Batch of image tensors accepted by the backbone.

        Returns:
            Tensor with one ``embedding_size``-dimensional row per image.
        """
        # Only the backbone is excluded from autograd. The projection and
        # batch-norm layers must stay outside no_grad, otherwise they never
        # receive gradients and cannot be trained.
        with torch.no_grad():
            resnet_features = self.resnet_module(input_images)
        # Flatten everything after the batch dimension.
        resnet_features = resnet_features.reshape(
            resnet_features.size(0), -1)
        final_features = self.batch_norm(
            self.linear_layer(resnet_features))
        return final_features

We have defined two sub-models – that is, a CNN model and an RNN model. For the CNN part, we use a pretrained CNN model available under the PyTorch models repository: the ResNet 152 architecture. As we have learned in Chapter 2, Deep CNN Architectures, this deep CNN model with 152 layers is pretrained on the ImageNet dataset [5]. The ImageNet dataset contains over 1.4 million RGB images labeled over 1,000 classes. These 1,000 classes belong to categories such as plants, animals, food, sports, and more

The LSTM layer takes in the embedding vectors as input and outputs a sequence of words that should ideally describe the image from which the embedding was generated.

In [12]:
class LSTMModel(nn.Module):
    """Caption decoder: token embedding -> LSTM -> linear layer over the vocabulary."""

    def __init__(self, embedding_size, hidden_layer_size,
                 vocabulary_size, num_layers,
                 max_seq_len=20):
        """Create the decoder layers.

        Args:
            embedding_size: Size of the token embeddings; the image feature
                vector fed to the LSTM must have the same size.
            hidden_layer_size: Hidden size of the LSTM.
            vocabulary_size: Number of tokens in the vocabulary.
            num_layers: Number of stacked LSTM layers.
            max_seq_len: Number of tokens ``sample`` generates per image.
        """
        super().__init__()
        self.embedding_layer = nn.Embedding(vocabulary_size, embedding_size)
        self.lstm_layer = nn.LSTM(embedding_size,
                                  hidden_layer_size,
                                  num_layers,
                                  batch_first=True)
        self.linear_layer = nn.Linear(hidden_layer_size,
                                      vocabulary_size)
        self.max_seq_len = max_seq_len

    def forward(self, input_features, captions, lengths):
        """Decode image feature vectors and generate captions.

        Args:
            input_features: Image features, one row per batch element.
            captions: Padded batch of caption token ids (batch first).
            lengths: Valid length of every caption, as a tensor or a plain
                sequence of ints.

        Returns:
            Vocabulary scores for every non-padded time step, laid out in
            the order of the packed sequence.
        """
        embeddings = self.embedding_layer(captions)
        # Prepend the image feature as the first time step of every sequence.
        inputs = torch.cat((input_features.unsqueeze(1), embeddings), 1)
        # Tensor lengths are moved to the CPU for packing; plain sequences
        # are accepted unchanged.
        if torch.is_tensor(lengths):
            lengths = lengths.cpu()
        packed_sequence = pack_padded_sequence(
            inputs, lengths, batch_first=True, enforce_sorted=False)
        hidden_states, _ = self.lstm_layer(packed_sequence)
        # hidden_states is a PackedSequence; index 0 is its flat data tensor.
        outputs = self.linear_layer(hidden_states[0])
        return outputs

    def sample(self, input_features, lstm_states=None):
        """Generate caps for feats with greedy search.

        Args:
            input_features: Image features, one row per batch element.
            lstm_states: Optional initial ``(h, c)`` state for the LSTM.

        Returns:
            Tensor of shape ``(batch, max_seq_len)`` with the chosen token ids.
        """
        sampled_indices = []
        inputs = input_features.unsqueeze(1)

        for _ in range(self.max_seq_len):
            # nn.LSTM returns (output, (h_n, c_n)); the second element is
            # the state to feed into the next step.
            hidden_states, lstm_states = self.lstm_layer(inputs, lstm_states)

            outputs = self.linear_layer(hidden_states.squeeze(1))

            # Greedy choice: the highest-scoring token at this step.
            _, predicted = outputs.max(1)
            sampled_indices.append(predicted)

            # The chosen token becomes the input of the next step.
            inputs = self.embedding_layer(predicted)
            inputs = inputs.unsqueeze(1)

        sampled_indices = torch.stack(sampled_indices, 1)
        return sampled_indices

Training the CNN-LSTM model

In [13]:
# Device configuration: use the GPU when CUDA is available, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() \
    else torch.device('cpu')


We use torchvision’s transforms module to randomly crop and flip the input images and to normalize their pixel values:

In [14]:
# Image pre-processing pipeline applied to every training image;
# the normalization constants are the ones used for the pretrained resnet.
transform = transforms.Compose([
    # Random 224x224 crop.
    transforms.RandomCrop(224),
    # Random left-right flip.
    transforms.RandomHorizontalFlip(),
    # PIL image -> tensor.
    transforms.ToTensor(),
    # Per-channel (mean, std) normalization.
    transforms.Normalize(
        (0.485, 0.456, 0.406),
        (0.229, 0.224, 0.225),
    ),
])


First, we load the vocabulary that was built in the "Preprocessing caption (text) data" section. We also initialize the data loader using the get_loader() function defined in the "Defining the image captioning data loader" section:

In [16]:
# Load vocab wrapper
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# load a vocabulary.pkl that this notebook produced.
with open('data_dir/vocabulary.pkl', 'rb') as f:
    vocabulary = pickle.load(f)

# Instantiate data loader (batch size 128, shuffled).
# NOTE(review): get_loader needs CustomCocoDataset and collate_function to be
# defined first; the recorded output of this cell is a NameError for
# CustomCocoDataset.
# NOTE(review): the resize cell in this notebook writes the val2014 images to
# './data_dir/resized_images_val2014', while this call reads
# 'data_dir/resized_images' together with the train2014 captions — confirm
# that the image directory and the annotation file match.
custom_data_loader = get_loader('data_dir/resized_images',
    'data_dir/annotations/captions_train2014.json',
    vocabulary, transform, 128, shuffle=True)


NameError: name 'CustomCocoDataset' is not defined

In [None]:
# Build models: encoder with 256-dim image embeddings, decoder with 256-dim
# token embeddings, a 512-unit hidden state and a single LSTM layer.
encoder_model = CNNModel(256).to(device)
decoder_model = LSTMModel(256, 512, len(vocabulary), 1).to(device)

# Loss & optimizer. Only the decoder and the encoder's projection and
# batch-norm layers are handed to the optimizer.
loss_criterion = nn.CrossEntropyLoss()
parameters = (
    list(decoder_model.parameters())
    + list(encoder_model.linear_layer.parameters())
    + list(encoder_model.batch_norm.parameters())
)
optimizer = torch.optim.Adam(parameters, lr=0.001)


In [17]:
# Train for 5 epochs over the caption data loader.
for epoch in range(5):
    for i, (imgs, caps, lens) in enumerate(custom_data_loader):
        # The models were moved to `device`, so the batch must live there
        # too; on a CPU-only machine these calls are no-ops.
        imgs = imgs.to(device)
        caps = caps.to(device)
        # Flatten the padded captions into the packed layout so that they
        # line up with the decoder output, which is packed as well.
        tgts = pack_padded_sequence(caps, lens,
                                    batch_first=True)[0]
        # Forward pass, backward propagation
        feats = encoder_model(imgs)
        outputs = decoder_model(feats, caps, lens)
        loss = loss_criterion(outputs, tgts)
        # Clear gradients left over from the previous step before backprop.
        decoder_model.zero_grad()
        encoder_model.zero_grad()
        loss.backward()
        optimizer.step()


NameError: name 'custom_data_loader' is not defined

# New Section