# **Image Captioning using ResNet-152 and LSTM**

## **Part 2: Training the Decoder LSTM for Caption Generation**

Once images have been encoded into feature vectors, a **Long Short-Term Memory (LSTM) network** is trained as a decoder to generate textual descriptions of the images.

### Steps:

1. **Initialize the LSTM Decoder**: The decoder receives the image feature vector from the CNN encoder as its first input, then processes the caption's word embeddings step by step to predict each next word.
2. **Use an Embedding Layer**: Each word index in the vocabulary is mapped to a dense, trainable vector instead of a sparse one-hot encoding, giving the LSTM a compact, learnable word representation.
3. **Add a Fully Connected Output Layer**: A linear layer maps each LSTM hidden state to a score (logit) for every word in the vocabulary.
4. **Optimize with Cross-Entropy Loss**: The decoder is trained with categorical cross-entropy between the predicted word distribution and the actual next word in the caption.
5. **Incorporate Teacher Forcing**: During training, the ground-truth previous word from the caption, rather than the model's own prediction, is fed in as the next input, which makes training more stable.

Once trained, the LSTM decoder is used to generate fluent, contextually relevant captions for previously unseen images.
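The `DecoderLSTM` class used in this notebook is imported from the project's `Pipeline` package and its code is not shown here. As a minimal sketch of how the five steps above typically fit together, assuming 2048-dimensional ResNet-152 pooled features and the same `(features, captions, lengths)` call signature that the training loop uses, a decoder could look like this. The class name `SketchDecoderLSTM`, the `feature_proj` layer, and all layer sizes are hypothetical, not taken from the project.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence


class SketchDecoderLSTM(nn.Module):
    """Hypothetical decoder: image feature + word embeddings -> LSTM -> vocabulary logits."""

    def __init__(self, feature_size, embed_size, hidden_size, vocab_size):
        super().__init__()
        # step 1: project the CNN feature vector into the word-embedding space
        self.feature_proj = nn.Linear(feature_size, embed_size)
        # step 2: a dense, trainable vector for every word index
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        # step 3: map each hidden state to one score per vocabulary word
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions, lengths):
        image_step = self.feature_proj(features).unsqueeze(1)  # (B, 1, E)
        # step 5 (teacher forcing): ground-truth tokens are the inputs at every step
        word_steps = self.embed(captions)                        # (B, T, E)
        inputs = torch.cat([image_step, word_steps], dim=1)     # (B, T + 1, E)
        # the output at step t is trained to predict captions[:, t];
        # packing keeps only the first lengths[i] steps, so padding never reaches the LSTM
        packed = pack_padded_sequence(inputs, lengths, batch_first=True)
        hidden, _ = self.lstm(packed)
        return self.fc(hidden.data)                              # (sum(lengths), vocab_size)


# usage with random stand-in data
decoder = SketchDecoderLSTM(feature_size=2048, embed_size=256, hidden_size=512, vocab_size=1000)
features = torch.randn(3, 2048)
captions = torch.randint(0, 1000, (3, 7))
lengths = [7, 5, 4]  # sorted longest first, as pack_padded_sequence expects by default
logits = decoder(features, captions, lengths)
print(logits.shape)  # torch.Size([16, 1000])

# step 4: cross-entropy against the targets packed in the same order
targets = pack_padded_sequence(captions, lengths, batch_first=True).data
loss = nn.CrossEntropyLoss()(logits, targets)
```

Prepending the image as step 0 means the decoder never needs a separate initial hidden state; an alternative design initializes the LSTM's hidden state from the image feature instead.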

## **Technologies & Libraries**
- **Deep Learning Framework:** PyTorch and TensorFlow
- **Pretrained Model:** ResNet-152
- **Decoder Model:** Custom LSTM-based decoder, developed in PyTorch
- **Text Processing:** Hugging Face Tokenizers
- **Dataset:** Flickr8k
- **Training Strategy:** Teacher Forcing

---

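```python
import sys
sys.path.insert(0, '../')  # make the project's Pipeline package importable

import pickle

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from tqdm import tqdm

from Pipeline.modelling.models.DecoderLSTM import *
from Pipeline.modelling.dataloader.Feature_Caption_Dataloader import *
```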

### Load the encoded features and the captions

```python
# load the encoded image features
encoded_features = torch.load('./data/encoded_images_features.pt', weights_only=True)

# load the captions (a pickled list), opening the file in read-binary mode
with open("data/captions.pkl", "rb") as file:
    captions = pickle.load(file)
```

### Create a dataloader for the features and captions

```python
# yields batches of (features, padded caption token IDs, caption lengths)
decoder_dataloader = Feature_Caption_Dataloader(encoded_features, captions, batch_size=64, shuffle=False)
```
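`Feature_Caption_Dataloader` comes from the project's own package, so its internals are not shown. The training loop expects each batch to yield `(features, targets, lengths)`. A minimal sketch of one way to produce such batches with a standard `torch.utils.data.DataLoader` is a custom `collate_fn` that pads the captions and sorts each batch longest-first, which `pack_padded_sequence` requires by default. The function name `collate_features_captions`, `PAD_ID = 0`, and the toy data are assumptions. The sketch shuffles, as is common for training, whereas the notebook's loader uses `shuffle=False`.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

PAD_ID = 0  # assumption: ID of the padding token in the vocabulary


def collate_features_captions(batch):
    """Hypothetical collate_fn: batch is a list of (feature_vector, caption_ids) pairs."""
    # sort by caption length, longest first, as pack_padded_sequence expects by default
    batch = sorted(batch, key=lambda pair: len(pair[1]), reverse=True)
    features, captions = zip(*batch)
    lengths = [len(c) for c in captions]
    features = torch.stack(features)                                           # (B, feature_size)
    targets = pad_sequence(list(captions), batch_first=True, padding_value=PAD_ID)  # (B, max_len)
    return features, targets, lengths


# toy data: 4 stand-in image features with captions of different lengths (IDs 1..49, never PAD_ID)
pairs = [(torch.randn(2048), torch.randint(1, 50, (n,))) for n in (3, 6, 4, 5)]
loader = DataLoader(pairs, batch_size=4, shuffle=True, collate_fn=collate_features_captions)
features, targets, lengths = next(iter(loader))
print(features.shape, targets.shape, lengths)
# → torch.Size([4, 2048]) torch.Size([4, 6]) [6, 5, 4, 3]
```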

### Create an instance of the Decoder

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
```

```
cuda
```


```python
# reuse the tokenizer from the encoding step so the vocabulary matches the captions
with open("data/custom_tokenizer.pkl", "rb") as file:
    tokenizer = pickle.load(file)
```

```python
decoderLSTM = DecoderLSTM(
    vocab_size=len(tokenizer.my_vocabulary),
    tokenizer=tokenizer,
)
```

```python
epoch_number = 30

decoder = decoderLSTM.to(device)
optimiser = torch.optim.Adam(decoder.parameters(), lr=0.001)
loss_function = nn.CrossEntropyLoss()

decoder.train()
for epoch in tqdm(range(epoch_number)):
    for features, targets, lengths in decoder_dataloader:
        features = features.to(device)
        targets = targets.to(device)

        optimiser.zero_grad()
        # teacher forcing: the ground-truth caption tokens are passed to the decoder as inputs
        outputs = decoder(features, targets, lengths)

        # pack the targets the same way as the outputs so that padding is excluded from the loss
        targets = pack_padded_sequence(targets, lengths, batch_first=True)
        loss_value = loss_function(outputs, targets.data)
        loss_value.backward()
        optimiser.step()

    # reports the loss of the last batch in the epoch, not the epoch average
    print(f'Epoch {epoch} loss: {loss_value.item():.4f}')
```

```
Epoch 0 loss: 2.9793
Epoch 1 loss: 2.6482
Epoch 2 loss: 2.3932
Epoch 3 loss: 2.2153
Epoch 4 loss: 2.0772
Epoch 5 loss: 1.9144
Epoch 6 loss: 1.7602
Epoch 7 loss: 1.6099
Epoch 8 loss: 1.4845
Epoch 9 loss: 1.3577
Epoch 10 loss: 1.2769
Epoch 11 loss: 1.1715
Epoch 12 loss: 1.0948
Epoch 13 loss: 1.0256
Epoch 14 loss: 0.9424
Epoch 15 loss: 0.8615
Epoch 16 loss: 0.8065
Epoch 17 loss: 0.7581
Epoch 18 loss: 0.6869
Epoch 19 loss: 0.6402
Epoch 20 loss: 0.6029
Epoch 21 loss: 0.5632
Epoch 22 loss: 0.5432
Epoch 23 loss: 0.5147
Epoch 24 loss: 0.4867
Epoch 25 loss: 0.4506
Epoch 26 loss: 0.4484
Epoch 27 loss: 0.4532
Epoch 28 loss: 0.4073
Epoch 29 loss: 0.3918
100%|██████████| 30/30 [04:22<00:00,  8.75s/it]
```

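Packing the targets matters because `pack_padded_sequence` stores only the non-padded entries, in time-major order (all first tokens, then all second tokens, and so on), which is the same order as the decoder's packed outputs. The small example below, with made-up token IDs and random stand-in scores, shows that ordering and checks that the packed loss equals a masked loss using `ignore_index`, assuming the padding ID is 0 and never appears as a real token.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
vocab_size = 10

# two padded captions (0 = padding), sorted longest first
targets = torch.tensor([[5, 6, 7, 8],
                        [9, 3, 0, 0]])
lengths = [4, 2]

packed = pack_padded_sequence(targets, lengths, batch_first=True)
print(packed.data)         # tensor([5, 9, 6, 3, 7, 8])  time-major, padding dropped
print(packed.batch_sizes)  # tensor([2, 2, 1, 1])        captions still active at each step

# stand-in decoder scores for every (caption, step) position: (B, T, V)
padded_logits = torch.randn(2, 4, vocab_size)
packed_logits = pack_padded_sequence(padded_logits, lengths, batch_first=True).data

# loss on packed tensors, as in the training loop
loss_packed = nn.CrossEntropyLoss()(packed_logits, packed.data)
# equivalent: loss on the padded tensors, ignoring padding positions
loss_masked = nn.CrossEntropyLoss(ignore_index=0)(
    padded_logits.reshape(-1, vocab_size), targets.reshape(-1)
)
print(torch.allclose(loss_packed, loss_masked))  # True
```

Either form keeps padding out of the loss; packing additionally lets the LSTM skip the padded steps entirely.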
```python
# save the trained decoder weights
torch.save(decoder.state_dict(), 'data/trained_decoder.pt')
```
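To caption an unseen image, the saved weights would be reloaded into a freshly constructed decoder, for example with `decoder.load_state_dict(torch.load('data/trained_decoder.pt', map_location=device, weights_only=True))` followed by `decoder.eval()`. Teacher forcing is no longer possible at that point because there is no ground-truth caption, so each predicted word is fed back in as the next input. The sketch below shows greedy decoding with stand-alone `nn.Embedding`, `nn.LSTM`, and `nn.Linear` modules, assuming the image feature has already been projected to the embedding size and is used as the first input, as in step 1. The function name `greedy_caption`, `end_id`, and `max_len` are hypothetical.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def greedy_caption(image_step, embed, lstm, fc, end_id, max_len=20):
    """Hypothetical greedy decoder: feeds back its own predictions (no teacher forcing).

    image_step: (1, 1, embed_size) image feature already projected to the embedding size.
    Returns the predicted token IDs, stopping at end_id or after max_len steps.
    """
    inputs, state, token_ids = image_step, None, []
    for _ in range(max_len):
        hidden, state = lstm(inputs, state)             # hidden: (1, 1, hidden_size)
        next_id = fc(hidden.squeeze(1)).argmax(dim=-1)  # most likely word, shape (1,)
        token_ids.append(next_id.item())
        if next_id.item() == end_id:
            break
        inputs = embed(next_id).unsqueeze(1)            # (1, 1, embed_size)
    return token_ids


# usage with untrained stand-in modules
torch.manual_seed(0)
vocab_size, embed_size, hidden_size = 50, 16, 32
embed = nn.Embedding(vocab_size, embed_size)
lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
fc = nn.Linear(hidden_size, vocab_size)
image_step = torch.randn(1, 1, embed_size)  # stand-in for a projected ResNet-152 feature
token_ids = greedy_caption(image_step, embed, lstm, fc, end_id=2, max_len=10)
```

The resulting IDs would then be mapped back to words with the project's tokenizer, whose decoding method is not shown in this notebook. Greedy decoding is the simplest choice; beam search often yields better captions at extra cost.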

---