<a href="https://www.kaggle.com/code/prasannakasar/image-captioning?scriptVersionId=220256340" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
import pandas as pd
pd.set_option('display.max_colwidth', None)

In [2]:
# df = pd.read_csv("/kaggle/input/flickr8k/captions.txt", sep=",")

In [3]:
BASE_PATH = '../input/coco-2017-dataset/coco2017'

In [4]:
import json

with open(f'{BASE_PATH}/annotations/captions_train2017.json', 'r') as f:
    data = json.load(f)
    data = data['annotations']

img_cap_pairs = []

for sample in data:
    img_name = '%012d.jpg' % sample['image_id']
    img_cap_pairs.append([img_name, sample['caption']])

df = pd.DataFrame(img_cap_pairs, columns=['image', 'caption'])
df['image'] = df['image'].apply(
    lambda x: f'{BASE_PATH}/train2017/{x}'
)
df = df.reset_index(drop=True)
df.head()

Unnamed: 0,image,caption
0,../input/coco-2017-dataset/coco2017/train2017/000000203564.jpg,A bicycle replica with a clock as the front wheel.
1,../input/coco-2017-dataset/coco2017/train2017/000000322141.jpg,A room with blue walls and a white sink and door.
2,../input/coco-2017-dataset/coco2017/train2017/000000016977.jpg,A car that seems to be parked illegally behind a legally parked car
3,../input/coco-2017-dataset/coco2017/train2017/000000106140.jpg,A large passenger airplane flying through the air.
4,../input/coco-2017-dataset/coco2017/train2017/000000106140.jpg,There is a GOL plane taking off in a partly cloudy sky.


In [5]:
len(df)

591753

In [6]:
import re

df['cleaned_caption'] = df['caption'].apply(lambda caption: re.sub(r"[^a-zA-Z0-9 ]", "", caption))  # Remove punctuation
df['cleaned_caption'] = df['cleaned_caption'].apply(lambda caption: caption.lower().split())  # Convert to lowercase and split
df['cleaned_caption'] = df['cleaned_caption'].apply(lambda lis: ['<start>'] + [word for word in lis if word not in {"a", "an", "the"}] + ['<end>'])  # Remove stop words and add tokens


In [7]:
df.head(3)


Unnamed: 0,image,caption,cleaned_caption
0,../input/coco-2017-dataset/coco2017/train2017/000000203564.jpg,A bicycle replica with a clock as the front wheel.,"[<start>, bicycle, replica, with, clock, as, front, wheel, <end>]"
1,../input/coco-2017-dataset/coco2017/train2017/000000322141.jpg,A room with blue walls and a white sink and door.,"[<start>, room, with, blue, walls, and, white, sink, and, door, <end>]"
2,../input/coco-2017-dataset/coco2017/train2017/000000016977.jpg,A car that seems to be parked illegally behind a legally parked car,"[<start>, car, that, seems, to, be, parked, illegally, behind, legally, parked, car, <end>]"


In [8]:
df['seq_len'] = df['cleaned_caption'].apply(lambda x : len(x))
max_len = df['seq_len'].max()
max_len

48

In [9]:
df['cleaned_caption'].apply(len).idxmax()

495542

In [10]:
# df['cleaned_caption'] = df['cleaned_caption'].apply(lambda lis : lis + ['<pad>'] * (max_len - len(lis)))

In [11]:
df.head(3)

Unnamed: 0,image,caption,cleaned_caption,seq_len
0,../input/coco-2017-dataset/coco2017/train2017/000000203564.jpg,A bicycle replica with a clock as the front wheel.,"[<start>, bicycle, replica, with, clock, as, front, wheel, <end>]",9
1,../input/coco-2017-dataset/coco2017/train2017/000000322141.jpg,A room with blue walls and a white sink and door.,"[<start>, room, with, blue, walls, and, white, sink, and, door, <end>]",11
2,../input/coco-2017-dataset/coco2017/train2017/000000016977.jpg,A car that seems to be parked illegally behind a legally parked car,"[<start>, car, that, seems, to, be, parked, illegally, behind, legally, parked, car, <end>]",13


In [12]:
word_list = []
df['cleaned_caption'].apply(lambda lis: [word_list.append(word) for word in lis])

0                                                         [None, None, None, None, None, None, None, None, None]
1                                             [None, None, None, None, None, None, None, None, None, None, None]
2                                 [None, None, None, None, None, None, None, None, None, None, None, None, None]
3                                                               [None, None, None, None, None, None, None, None]
4                                       [None, None, None, None, None, None, None, None, None, None, None, None]
                                                           ...                                                  
591748                                  [None, None, None, None, None, None, None, None, None, None, None, None]
591749                            [None, None, None, None, None, None, None, None, None, None, None, None, None]
591750                                              [None, None, None, None, None, None, None, N

In [13]:
from collections import Counter
word_dict = Counter(word_list)

In [14]:
min_freq = 5
filtered_words = [word for word, freq in word_dict.items() if freq >= min_freq]
filtered_words = ['<unk>', '<pad>'] + filtered_words

In [15]:
word_to_index = {word: idx for idx, word in enumerate(filtered_words)}
index_to_word = {idx: word for idx, word in enumerate(filtered_words)}

In [16]:
df['word_token'] = df['cleaned_caption'].apply(lambda lis : [word_to_index[word] if word in word_to_index else word_to_index['<unk>'] for word in lis])

In [17]:
df.head(2)

Unnamed: 0,image,caption,cleaned_caption,seq_len,word_token
0,../input/coco-2017-dataset/coco2017/train2017/000000203564.jpg,A bicycle replica with a clock as the front wheel.,"[<start>, bicycle, replica, with, clock, as, front, wheel, <end>]",9,"[2, 3, 4, 5, 6, 7, 8, 9, 10]"
1,../input/coco-2017-dataset/coco2017/train2017/000000322141.jpg,A room with blue walls and a white sink and door.,"[<start>, room, with, blue, walls, and, white, sink, and, door, <end>]",11,"[2, 11, 5, 12, 13, 14, 15, 16, 14, 17, 10]"


In [18]:
max_seq_len = max(df['seq_len'])

In [19]:
df.drop(columns=['caption', 'cleaned_caption', 'seq_len'], inplace=True)
df.head(3)

Unnamed: 0,image,word_token
0,../input/coco-2017-dataset/coco2017/train2017/000000203564.jpg,"[2, 3, 4, 5, 6, 7, 8, 9, 10]"
1,../input/coco-2017-dataset/coco2017/train2017/000000322141.jpg,"[2, 11, 5, 12, 13, 14, 15, 16, 14, 17, 10]"
2,../input/coco-2017-dataset/coco2017/train2017/000000016977.jpg,"[2, 18, 19, 20, 21, 22, 23, 24, 25, 0, 23, 18, 10]"


In [20]:
train_size = int(0.9*len(df))
test_size = len(df) - train_size
#/kaggle/input/flickr8k/Images

In [21]:
import torch
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import matplotlib.pyplot as plt

class ImageDataset(Dataset):
    def __init__(self, img_dir, dataframe):
        self.img_dir = img_dir
        self.dataframe = dataframe
        self.scaler = transforms.Resize([299, 299])
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_name = self.dataframe.iloc[idx, 0]
        label = self.dataframe.iloc[idx, 1]
        
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path)
        if image.mode == "L":
            image = image.convert("RGB")
        t_img = self.normalize(self.to_tensor(self.scaler(image)))
        return t_img, torch.tensor(label)

In [22]:
img_dir = '/kaggle/input/'
dataset = ImageDataset(img_dir, df)

In [23]:
batch_size = 256
vocab_size = len(word_dict)
d_model = 256
nhead = 8
num_encoder_layers = 6
num_decoder_layers = 6
dim_feedforward = 2048
learning_rate = 0.001
dropout = 0.5

In [24]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """
    Custom collate function for dynamic padding.
    batch: A list of tuples (image, caption).
    """
    images, captions = zip(*batch)
    
    # Pad captions to the maximum length in the batch
    captions = [torch.tensor(caption) for caption in captions]
    captions = pad_sequence(captions, batch_first=True, padding_value=word_to_index['<pad>'])  # 0 is the index for <pad>
    
    # Stack images into a single tensor
    images = torch.stack(images, dim=0)
    
    return images, captions

# Create DataLoader with dynamic padding
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers=4)

In [25]:
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class ImageCaptioningModel(nn.Module):
    def __init__(self, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout):
        super(ImageCaptioningModel, self).__init__()
        self.inception = models.inception_v3(pretrained=True)
        
        for param in self.inception.parameters():
            param.requires_grad = False
            
        self.inception.fc = nn.Linear(self.inception.fc.in_features, d_model)
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_len, d_model))
        self.transformer = nn.Transformer(d_model=d_model,
                                          nhead=nhead,
                                          num_encoder_layers=num_encoder_layers,
                                          num_decoder_layers=num_decoder_layers,
                                          dim_feedforward=dim_feedforward,
                                          dropout=dropout,
                                          batch_first=True)
        self.fc_out = nn.Linear(d_model, vocab_size)
        
        for param in self.inception.fc.parameters():
            param.requires_grad = True
            
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, image, caption): 
        # print(image.size())
        features = self.inception(image)

        if isinstance(features, tuple):
            features = features[0]

        features = self.dropout(self.relu(features))
        features = features.unsqueeze(1)

        caption_embedding = self.embedding(caption)
        caption_embedding = self.dropout(caption_embedding)
        # caption_embedding = caption_embedding.permute(1, 0, 2)
        # print(caption_embedding.size())
        caption_len = caption_embedding.size(1)
        caption_embedding = caption_embedding + self.positional_encoding[:, :caption_len, :]

        tgt_mask = self.generate_square_subsequent_mask(caption_len).to(device)
        # print(f'features dim = {features.size()}')
        # print(f'caption embedding dim = {caption_embedding.size()}')
        output = self.transformer(src=features, tgt=caption_embedding, tgt_mask=tgt_mask)

        output = self.fc_out(output)
        output = output.permute(1, 0, 2)
        return output

    def generate_square_subsequent_mask(self, sz):
        """Generate a causal mask to prevent the decoder from attending to future tokens."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

In [26]:
# class LSTMCaptionGenerator(nn.Module):
#     def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
#         super(LSTMCaptionGenerator, self).__init__()
#         self.embed = nn.Embedding(vocab_size, embed_size)
#         self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
#         self.linear = nn.Linear(hidden_size, vocab_size)
#         self.dropout = nn.Dropout(0.5)
    
#     def forward(self, features, captions):
#         embeddings = self.embed(captions)
#         embeddings = self.dropout(embeddings)
#         # print(f"dim of features after unsqueezzing = ", features.unsqueeze(1).size())
#         # print(f"dim of embeddings = ", embeddings.size())
#         # repeated_features = features.unsqueeze(1).repeat(1, embeddings.size(1), 1)
#         print(features.size())
#         print(embeddings.size())
#         embeddings = torch.cat((features.unsqueeze(1), embeddings[:, :-1, :]), dim=1)
#         # embeddings = torch.cat((repeated_features, embeddings), dim=1) 
#         # print(f"embedding dim={embeddings.size()}")
#         hiddens, _ = self.lstm(embeddings)
#         outputs = self.linear(hiddens)
#         return outputs

In [27]:
# class CNN_LSTM_model(nn.Module):  # Must inherit from nn.Module
#     def __init__(self, vocab_size, embed_size, hidden_size, num_layers, max_seq_length):
#         super(CNN_LSTM_model, self).__init__()
#         self.CNN_model = CNNFeatureExtractor()  # Initialize CNN feature extractor
#         self.LSTM_model = LSTMCaptionGenerator(vocab_size, embed_size, hidden_size, num_layers, max_seq_length)  # Initialize LSTM model

#     def forward(self, image, captions):
#         features = self.CNN_model(image)  # Get image features from CNN
#         outputs = self.LSTM_model(features, captions)  # Use image features and captions in LSTM
#         return outputs

In [28]:
import torch

device = ""
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = ImageCaptioningModel(d_model=d_model,
                             nhead=nhead,
                             num_encoder_layers=num_encoder_layers,
                             num_decoder_layers=num_decoder_layers,
                             dim_feedforward=dim_feedforward,
                             dropout=dropout).to(device)

Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth
100%|██████████| 104M/104M [00:00<00:00, 198MB/s] 


In [29]:
index_to_word[0]

'<unk>'

In [30]:
from nltk.translate.bleu_score import corpus_bleu

def calculate_bleu_score(predictions, references, idx2word, tokenizer):
    """
    Calculate BLEU score for a batch of predictions vs references.
    predictions: List of predicted captions (list of words).
    references: List of ground truth captions (list of words).
    """
    # Convert predicted captions and ground truth into the format required by BLEU
    # References should be a list of lists, and predictions should be a list of sentences.
    bleu_score = corpus_bleu(references, predictions)
    return bleu_score

In [31]:
from tqdm import tqdm
import torch
import torch.nn as nn
import nltk
import torch.nn.utils as utils

def train_and_test(num_epochs):
    # Initialize optimizer and loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=word_to_index['<pad>'])

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        total_loss_train = 0
        for image, caption in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} Training"):
            image = image.to(device)
            caption = caption.to(device)

            # print(image.size())

            optimizer.zero_grad()
            predicted_caption = model(image, caption)
            predicted_caption = predicted_caption.reshape(-1, predicted_caption.size(-1))
            caption = caption.reshape(-1)

            loss = criterion(predicted_caption, caption)
            loss.backward()
            utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            total_loss_train += loss.item()

        print(f"Train Loss at epoch {epoch+1} = {total_loss_train / len(train_loader)}")

        # Testing phase
        model.eval()
        total_loss_test = 0

        for image, caption in tqdm(test_loader, desc=f"Epoch {epoch+1}/{num_epochs} Testing"):
            image = image.to(device)
            caption = caption.to(device)

            predicted_caption = model(image, caption)
            predicted_caption = predicted_caption.reshape(-1, predicted_caption.size(-1))
            caption = caption.reshape(-1)
            # Calculate loss for test
            loss = criterion(predicted_caption, caption)
            total_loss_test += loss.item()

        print(f"Test Loss at epoch {epoch+1} = {total_loss_test / len(test_loader)}")


In [32]:
train_and_test(10)

  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
Epoch 1/10 Training: 100%|██████████| 2081/2081 [44:40<00:00,  1.29s/it]


Train Loss at epoch 1 = 6.90045301505203


  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
Epoch 1/10 Testing: 100%|██████████| 232/232 [04:55<00:00,  1.27s/it]


Test Loss at epoch 1 = 6.395128866721844


  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
Epoch 2/10 Training: 100%|██████████| 2081/2081 [44:06<00:00,  1.27s/it]


Train Loss at epoch 2 = 6.890484163475861


  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
Epoch 2/10 Testing: 100%|██████████| 232/232 [04:58<00:00,  1.29s/it]


Test Loss at epoch 2 = 6.395128529647301


  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
Epoch 3/10 Training: 100%|██████████| 2081/2081 [45:28<00:00,  1.31s/it]


Train Loss at epoch 3 = 6.8900830435214395


  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
Epoch 3/10 Testing: 100%|██████████| 232/232 [04:55<00:00,  1.28s/it]


Test Loss at epoch 3 = 6.395128652967256


  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
Epoch 4/10 Training: 100%|██████████| 2081/2081 [45:34<00:00,  1.31s/it]


Train Loss at epoch 4 = 6.8901147015217346


  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
Epoch 4/10 Testing: 100%|██████████| 232/232 [05:05<00:00,  1.32s/it]


Test Loss at epoch 4 = 6.395128685852577


  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
Epoch 5/10 Training: 100%|██████████| 2081/2081 [45:55<00:00,  1.32s/it]


Train Loss at epoch 5 = 6.890206906384198


  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
Epoch 5/10 Testing: 100%|██████████| 232/232 [05:00<00:00,  1.29s/it]


Test Loss at epoch 5 = 6.39512836111003


  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
Epoch 6/10 Training: 100%|██████████| 2081/2081 [45:05<00:00,  1.30s/it]


Train Loss at epoch 6 = 6.890297953409967


  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
Epoch 6/10 Testing: 100%|██████████| 232/232 [05:03<00:00,  1.31s/it]


Test Loss at epoch 6 = 6.395128320003378


  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
Epoch 7/10 Training: 100%|██████████| 2081/2081 [45:24<00:00,  1.31s/it]


Train Loss at epoch 7 = 6.890450845367344


  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
Epoch 7/10 Testing: 100%|██████████| 232/232 [05:07<00:00,  1.33s/it]


Test Loss at epoch 7 = 6.395128365220694


  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
Epoch 8/10 Training: 100%|██████████| 2081/2081 [46:21<00:00,  1.34s/it]


Train Loss at epoch 8 = 6.890212860776506


  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
Epoch 8/10 Testing: 100%|██████████| 232/232 [05:05<00:00,  1.32s/it]


Test Loss at epoch 8 = 6.39512824190074


  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
Epoch 9/10 Training: 100%|██████████| 2081/2081 [45:25<00:00,  1.31s/it]


Train Loss at epoch 9 = 6.890257402087335


  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
Epoch 9/10 Testing: 100%|██████████| 232/232 [04:56<00:00,  1.28s/it]


Test Loss at epoch 9 = 6.395128766010547


  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
Epoch 10/10 Training: 100%|██████████| 2081/2081 [45:03<00:00,  1.30s/it]


Train Loss at epoch 10 = 6.890125500359597


  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
  captions = [torch.tensor(caption) for caption in captions]
Epoch 10/10 Testing: 100%|██████████| 232/232 [04:55<00:00,  1.27s/it]

Test Loss at epoch 10 = 6.395128733125226





In [33]:
torch.save(model, "model_transformer_10_epochs.pth")

In [34]:
# encoder = torch.load("/kaggle/input/image-captioning_coco_10-epochs/pytorch/default/1/encoder_10_epochs.pth")
# decoder = torch.load("/kaggle/input/image-captioning_coco_10-epochs/pytorch/default/1/decoder_10_epochs.pth")

In [35]:
def load_img(idx):
    
    scaler = transforms.Resize([299, 299])
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    to_tensor = transforms.ToTensor()
    
    img_name = df.iloc[idx, 0]
    label = df.iloc[idx, 1]
        
    img_path = os.path.join(img_dir, img_name)
    image = Image.open(img_path)
    t_img = normalize(to_tensor(scaler(image)))
    # t_img = torch.tensor(t_img)
    # label = torch.tensor(label)
    return t_img, label

In [36]:
def inference(image, beam_width=5, max_seq_len=48):
    model.eval()
    start_token = torch.tensor([[1]]).to(device)  # Start token
    image = image.to(device)

    with torch.no_grad():
        # Get image features
        features = model.inception(image)
        if isinstance(features, tuple):
            features = features[0]
        features = model.dropout(model.relu(features))
        features = features.unsqueeze(1)  # Add sequence dimension

        # Initialize beam search
        sequences = [[start_token, 0.0]]  # List of [sequence, score]

        for _ in range(max_seq_len - 1):
            all_candidates = []
            for seq, score in sequences:
                if seq[0, -1].item() == 2:  # Stop if END token is generated
                    all_candidates.append([seq, score])
                    continue

                # Generate next tokens
                caption_embedding = model.embedding(seq)
                caption_embedding = model.dropout(caption_embedding)
                pos_encoding = model.positional_encoding[:, :caption_embedding.size(1), :]
                caption_embedding = caption_embedding + pos_encoding

                tgt_mask = model.generate_square_subsequent_mask(seq.size(1)).to(device)
                output = model.transformer(src=features, tgt=caption_embedding, tgt_mask=tgt_mask)
                output = model.fc_out(output)

                # Get top-k predictions
                log_probs = torch.log_softmax(output[:, -1, :], dim=-1)
                top_k_scores, top_k_tokens = log_probs.topk(beam_width, dim=-1)

                for i in range(beam_width):
                    candidate_seq = torch.cat([seq, top_k_tokens[0, i].unsqueeze(0).unsqueeze(0)], dim=1)
                    candidate_score = score + top_k_scores[0, i].item()
                    all_candidates.append([candidate_seq, candidate_score])

            # Select top-k candidates
            sequences = sorted(all_candidates, key=lambda x: x[1], reverse=True)[:beam_width]

        # Select the best sequence
        best_sequence = sequences[0][0]
        result_caption = [token.item() for token in best_sequence[0]]

    # Convert indices to words
    caption = [index_to_word[idx] for idx in result_caption]
    return caption

In [37]:
img, caption = load_img(500)
# print(img.size())
img = img.unsqueeze(0)
# print(img.size())
res = inference(img)
expected = [index_to_word[idx] for idx in caption]
print(res)
print(expected) 

['<pad>', '<start>']
['<start>', 'gang', 'of', 'bikers', 'sitting', 'on', 'top', 'of', 'motorcycles', 'on', 'sidewalk', '<end>']
