<a href="https://www.kaggle.com/code/prasannakasar/image-captioning?scriptVersionId=217245449" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
import pandas as pd
pd.set_option('display.max_colwidth', None)

In [2]:
df = pd.read_csv("/kaggle/input/flickr8k/captions.txt", sep=",")

In [3]:
df.head(3)

Unnamed: 0,image,caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set of stairs in an entry way .
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .


In [4]:
df['cleaned_caption'] = df['caption'].apply(lambda sentence : [word.lower() for word in sentence.split(" ") if word.isalpha()])
df['cleaned_caption'] = df['cleaned_caption'].apply(lambda lis : [word for word in lis if len(word) > 1])
df['cleaned_caption'] = df['cleaned_caption'].apply(lambda lis : ['<start>'] + lis + ['<end>'])

In [5]:
df.head(3)

Unnamed: 0,image,caption,cleaned_caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set of stairs in an entry way .,"[<start>, child, in, pink, dress, is, climbing, up, set, of, stairs, in, an, entry, way, <end>]"
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .,"[<start>, girl, going, into, wooden, building, <end>]"
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .,"[<start>, little, girl, climbing, into, wooden, playhouse, <end>]"


In [6]:
df['seq_len'] = df['cleaned_caption'].apply(lambda x : len(x))
max_len = df['seq_len'].max()

In [7]:
df['cleaned_caption'].apply(len).idxmax()

8049

In [8]:
df['cleaned_caption'] = df['cleaned_caption'].apply(lambda lis : lis + ['<pad>'] * (max_len - len(lis)))

In [9]:
df.head(3)

Unnamed: 0,image,caption,cleaned_caption,seq_len
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set of stairs in an entry way .,"[<start>, child, in, pink, dress, is, climbing, up, set, of, stairs, in, an, entry, way, <end>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>]",16
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .,"[<start>, girl, going, into, wooden, building, <end>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>]",7
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .,"[<start>, little, girl, climbing, into, wooden, playhouse, <end>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>]",8


In [10]:
word_list = []
df['cleaned_caption'].apply(lambda lis: [word_list.append(word) for word in lis])

0        [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]
1        [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]
2        [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]
3        [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]
4        [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, No

In [11]:
from collections import Counter
word_dict = Counter(word_list)

In [12]:
word_dict = sorted(word_dict, key=word_dict.get, reverse=True)

In [13]:
word_to_index = {word : index for index, word in enumerate(word_dict)}
index_to_word = {index : word for index, word in enumerate(word_dict)}

In [14]:
df['word_token'] = df['cleaned_caption'].apply(lambda lis : [word_to_index[word] for word in lis])

In [15]:
df.head(2)

Unnamed: 0,image,caption,cleaned_caption,seq_len,word_token
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set of stairs in an entry way .,"[<start>, child, in, pink, dress, is, climbing, up, set, of, stairs, in, an, entry, way, <end>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>]",16,"[1, 41, 3, 89, 168, 6, 118, 52, 392, 11, 389, 3, 27, 5075, 690, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .,"[<start>, girl, going, into, wooden, building, <end>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>]",7,"[1, 18, 311, 63, 192, 116, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"


In [16]:
df.drop(columns=['caption', 'cleaned_caption', 'seq_len'], inplace=True)
df.head(3)

Unnamed: 0,image,word_token
0,1000268201_693b08cb0e.jpg,"[1, 41, 3, 89, 168, 6, 118, 52, 392, 11, 389, 3, 27, 5075, 690, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
1,1000268201_693b08cb0e.jpg,"[1, 18, 311, 63, 192, 116, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
2,1000268201_693b08cb0e.jpg,"[1, 39, 18, 118, 63, 192, 2402, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"


In [17]:
train_size = int(0.8*len(df))
test_size = len(df) - train_size
#/kaggle/input/flickr8k/Images

In [18]:
import torch
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms

class ImageDataset(Dataset):
    def __init__(self, img_dir, dataframe):
        self.img_dir = img_dir
        self.dataframe = dataframe
        self.scaler = transforms.Resize([224, 224])
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_name = self.dataframe.iloc[idx, 0]
        label = self.dataframe.iloc[idx, 1]
        
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path)
        t_img = self.normalize(self.to_tensor(self.scaler(image)))
        # t_img = torch.tensor(t_img)
        # label = torch.tensor(label)

        return t_img, torch.tensor(label)

In [19]:
img_dir = '/kaggle/input/flickr8k/Images'
dataset = ImageDataset(img_dir, df)

In [20]:
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [21]:
import torch.nn as nn
import torch.nn.functional as F

class CNNFeatureExtractor(nn.Module):
    def __init__(self):
        super(CNNFeatureExtractor, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding = 1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding = 1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 56 * 56, 256)


    def forward(self, x, inference=False):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # print(x.size())
        x = x.view(x.size(0), -1) 
        if inference:
            x = x.view(1, -1) 
        # x = x.unsqueeze(0)
        # print(x.size())
        x = F.relu(self.fc1(x))
        return x

In [22]:
class LSTMCaptionGenerator(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(LSTMCaptionGenerator, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, features, captions):
        embeddings = self.embed(captions)
        embeddings = self.dropout(embeddings)
        # print(f"dim of features after unsqueezzing = ", features.unsqueeze(1).size())
        # print(f"dim of embeddings = ", embeddings.size())
        embeddings = torch.cat((features.unsqueeze(1), embeddings), dim=1)
        hiddens, _ = self.lstm(embeddings)
        outputs = self.linear(hiddens)
        return outputs

    def inference(self, image, max_seq_len=50):
        result_caption = []
        x = CNNFeatureExtractor(image).unsqueeze(1)
        
        for _ in range(max_seq_len):
            hidden, _ = self.lstm(x)
            output = self.linear(hidden)
            predicted = output.argmax(1)
            result_caption.appen(predicted.item())
            x = self.embed(predicted).unsqueeze(1)

            if predicted.item() == 2:
                break

        res = [index_to_word[idx] for idx in result_caption]
        print(res)

In [23]:
# class CNN_LSTM_model(nn.Module):  # Must inherit from nn.Module
#     def __init__(self, vocab_size, embed_size, hidden_size, num_layers, max_seq_length):
#         super(CNN_LSTM_model, self).__init__()
#         self.CNN_model = CNNFeatureExtractor()  # Initialize CNN feature extractor
#         self.LSTM_model = LSTMCaptionGenerator(vocab_size, embed_size, hidden_size, num_layers, max_seq_length)  # Initialize LSTM model

#     def forward(self, image, captions):
#         features = self.CNN_model(image)  # Get image features from CNN
#         outputs = self.LSTM_model(features, captions)  # Use image features and captions in LSTM
#         return outputs

In [24]:
import torch

vocab_size = len(word_dict)
embed_size = 256
hidden_size = 256
num_layers = 4
learning_rate = 0.001
max_seq_length = 33

device = ""
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

encoder = CNNFeatureExtractor().to(device)
decoder = LSTMCaptionGenerator(embed_size=embed_size, 
                       hidden_size=hidden_size,
                       vocab_size=vocab_size,
                       num_layers=num_layers).to(device)

In [25]:
from tqdm import tqdm

def train_and_test(num_epoch):
    #train
    optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=learning_rate)
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    for epoch in range(num_epoch):
        encoder.train()
        decoder.train()
        total_loss_train = 0
        pbar = tqdm(train_loader)
        for image, caption in pbar:
            image = image.to(device)
            caption = caption.to(device)
            optimizer.zero_grad()
            features = encoder(image)
            predicted_caption = decoder(features, caption)
            predicted_caption = predicted_caption[: , 1: , :]
            predicted_caption_idx = torch.argmax(predicted_caption, dim=-1)
            # print(f"predicted cpation dim=", predicted_caption.size())
            # print(f"predicted_caption type", predicted_caption.type())
            # print(f"caption dim=", caption.size())
            predicted_caption = predicted_caption.reshape(-1, predicted_caption.size(-1))
            caption = caption.reshape(-1)

            loss = criterion(predicted_caption, caption)
            loss.backward()
            optimizer.step()
            total_loss_train += loss.item()

        print(f"Train Loss at {epoch+1} = {total_loss_train / len(train_loader)}")

        encoder.eval()
        decoder.eval()
        total_loss_test = 0
        # pbar = tqdm(test_loader)
        for image, caption in tqdm(test_loader):
            image = image.to(device)
            caption = caption.to(device)
            features = encoder(image)
            predicted_caption = decoder(features, caption)
            predicted_caption = predicted_caption[: , 1: , :]
            # predicted_caption_idx = torch.argmax(predicted_caption, dim=-1)
            # print(f"predicted cpation dim=", predicted_caption.size())
            # print(f"predicted_caption type", predicted_caption.type())
            # print(f"caption dim=", caption.size())
            predicted_caption = predicted_caption.reshape(-1, predicted_caption.size(-1))
            caption = caption.reshape(-1)

            loss = criterion(predicted_caption, caption)
            total_loss_test += loss.item()

        print(f"Test Loss at {epoch+1} = {total_loss_test / len(test_loader)}")
            

In [26]:
train_and_test(20)

100%|██████████| 506/506 [04:45<00:00,  1.77it/s]


Train Loss at 1 = 5.172579422298627


100%|██████████| 127/127 [00:53<00:00,  2.36it/s]


Test Loss at 1 = 4.525338623467393


100%|██████████| 506/506 [04:01<00:00,  2.09it/s]


Train Loss at 2 = 3.685318502512845


100%|██████████| 127/127 [00:51<00:00,  2.47it/s]


Test Loss at 2 = 2.8715895742881954


100%|██████████| 506/506 [03:59<00:00,  2.11it/s]


Train Loss at 3 = 2.409504041841379


100%|██████████| 127/127 [00:50<00:00,  2.50it/s]


Test Loss at 3 = 2.0005728869926274


100%|██████████| 506/506 [03:58<00:00,  2.12it/s]


Train Loss at 4 = 1.7660605808491763


100%|██████████| 127/127 [00:53<00:00,  2.39it/s]


Test Loss at 4 = 1.4964983726110983


100%|██████████| 506/506 [04:02<00:00,  2.09it/s]


Train Loss at 5 = 1.4207864252946123


100%|██████████| 127/127 [00:53<00:00,  2.39it/s]


Test Loss at 5 = 1.2356310303755633


100%|██████████| 506/506 [04:02<00:00,  2.09it/s]


Train Loss at 6 = 1.1984267452724366


100%|██████████| 127/127 [00:52<00:00,  2.42it/s]


Test Loss at 6 = 1.0515523543508034


100%|██████████| 506/506 [04:00<00:00,  2.10it/s]


Train Loss at 7 = 1.0434712850058032


100%|██████████| 127/127 [00:51<00:00,  2.46it/s]


Test Loss at 7 = 0.9153659475131297


100%|██████████| 506/506 [04:05<00:00,  2.06it/s]


Train Loss at 8 = 0.9174178535994805


100%|██████████| 127/127 [00:54<00:00,  2.35it/s]


Test Loss at 8 = 0.7995667720404197


100%|██████████| 506/506 [04:08<00:00,  2.04it/s]


Train Loss at 9 = 0.8063449926762712


100%|██████████| 127/127 [00:52<00:00,  2.42it/s]


Test Loss at 9 = 0.6934787311891871


100%|██████████| 506/506 [04:03<00:00,  2.08it/s]


Train Loss at 10 = 0.7050962915180229


100%|██████████| 127/127 [00:52<00:00,  2.44it/s]


Test Loss at 10 = 0.6111637609680807


100%|██████████| 506/506 [04:02<00:00,  2.08it/s]


Train Loss at 11 = 0.6220247347953292


100%|██████████| 127/127 [00:53<00:00,  2.37it/s]


Test Loss at 11 = 0.5409188355047871


100%|██████████| 506/506 [04:03<00:00,  2.08it/s]


Train Loss at 12 = 0.5512595109553206


100%|██████████| 127/127 [00:52<00:00,  2.44it/s]


Test Loss at 12 = 0.4742785574882988


100%|██████████| 506/506 [04:01<00:00,  2.10it/s]


Train Loss at 13 = 0.49024389996358997


100%|██████████| 127/127 [00:53<00:00,  2.39it/s]


Test Loss at 13 = 0.42648333616144074


100%|██████████| 506/506 [04:01<00:00,  2.09it/s]


Train Loss at 14 = 0.44142965948864404


100%|██████████| 127/127 [00:52<00:00,  2.42it/s]


Test Loss at 14 = 0.38490072678862597


100%|██████████| 506/506 [04:03<00:00,  2.08it/s]


Train Loss at 15 = 0.39981414527878933


100%|██████████| 127/127 [00:53<00:00,  2.36it/s]


Test Loss at 15 = 0.3544952014065164


100%|██████████| 506/506 [04:03<00:00,  2.08it/s]


Train Loss at 16 = 0.36316004171673016


100%|██████████| 127/127 [00:52<00:00,  2.40it/s]


Test Loss at 16 = 0.3275213913889382


100%|██████████| 506/506 [04:02<00:00,  2.08it/s]


Train Loss at 17 = 0.33098591398815863


100%|██████████| 127/127 [00:53<00:00,  2.36it/s]


Test Loss at 17 = 0.31451371640670955


100%|██████████| 506/506 [04:04<00:00,  2.07it/s]


Train Loss at 18 = 0.30467793748901767


100%|██████████| 127/127 [00:53<00:00,  2.38it/s]


Test Loss at 18 = 0.2974329385466463


100%|██████████| 506/506 [04:04<00:00,  2.07it/s]


Train Loss at 19 = 0.27881444201403455


100%|██████████| 127/127 [00:53<00:00,  2.39it/s]


Test Loss at 19 = 0.2774302030172874


100%|██████████| 506/506 [04:04<00:00,  2.07it/s]


Train Loss at 20 = 0.2594658048315482


100%|██████████| 127/127 [00:53<00:00,  2.37it/s]

Test Loss at 20 = 0.27906125912985463





In [27]:
torch.save(encoder, "encoder_20_epochs.pth")
torch.save(decoder, "decoder_20_epochs.pth")

In [28]:
def load_img(idx):
    
    scaler = transforms.Resize([224, 224])
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    to_tensor = transforms.ToTensor()
    
    img_name = df.iloc[idx, 0]
    label = df.iloc[idx, 1]
        
    img_path = os.path.join(img_dir, img_name)
    image = Image.open(img_path)
    t_img = normalize(to_tensor(scaler(image)))
    # t_img = torch.tensor(t_img)
    # label = torch.tensor(label)
    return t_img, label

In [29]:
def inference(image, max_seq_len=50):
    result_caption = []
    image = image.to(device)
    # print(img.size())
    x = encoder(image, inference=True).unsqueeze(1)
    # print(f"expected dim of input for lstm={x.size()}")
            
    for _ in range(max_seq_len):
        hidden, _ = decoder.lstm(x)
        output = decoder.linear(hidden)
        # print(f"output dim={output.size()}")
        predicted = output.argmax(2)
        # print(f"predicted token={predicted.item()}")
        result_caption.append(predicted.item())
        x = decoder.embed(predicted)
        # print(f"next intput size={x.size()}")
    
        if predicted.item() == 2:
            break
    
    res = [index_to_word[idx] for idx in result_caption]
    print(res)

In [30]:
img, caption = load_img(0)
res = inference(img)
expected = [index_to_word[idx] for idx in caption]
print(expected)
print(res)

['shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades', 'shades']
['<start>', 'child', 'in', 'pink', 'dress', 'is', 'climbing', 'up', 'set', 'of', 'stairs', 'in', 'an', 'entry', 'way', '<end>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
None
